Miscellaneous Ops

The pyro.ops module implements tensor utilities that are mostly independent of the rest of Pyro.

Utilities for HMC

class DualAveraging(prox_center=0, t0=10, kappa=0.75, gamma=0.05)[source]

Bases: object

Dual Averaging is a scheme to solve convex optimization problems. It belongs to a class of subgradient methods, which use subgradients to update the parameters of a model (in primal space). Under some conditions, the averages of the parameters generated by the scheme are guaranteed to converge to an optimal value. However, a counter-intuitive aspect of traditional subgradient methods is that “new subgradients enter the model with decreasing weights” (see \([1]\)). The Dual Averaging scheme addresses this by weighting all subgradients (which lie in a dual space) equally when updating parameters, hence the name “dual averaging”.

This class implements a dual averaging scheme which is adapted for Markov chain Monte Carlo (MCMC) algorithms. To be more precise, we will replace subgradients by some statistics calculated during an MCMC trajectory. In addition, introducing some free parameters such as t0 and kappa is helpful and still guarantees the convergence of the scheme.

References

[1] Primal-dual subgradient methods for convex problems, Yurii Nesterov

[2] The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo, Matthew D. Hoffman, Andrew Gelman

Parameters
  • prox_center (float) – A “prox-center” parameter introduced in \([1]\) which pulls the primal sequence towards it.

  • t0 (float) – A free parameter introduced in \([2]\) that stabilizes the initial steps of the scheme.

  • kappa (float) – A free parameter introduced in \([2]\) that controls the weights of steps of the scheme. For a small kappa, the scheme will quickly forget states from early steps. This should be a number in \((0.5, 1]\).

  • gamma (float) – A free parameter which controls the speed of the convergence of the scheme.

reset()[source]
step(g)[source]

Updates states of the scheme given a new statistic/subgradient g.

Parameters

g (float) – A statistic calculated during an MCMC trajectory or subgradient.

get_state()[source]

Returns the latest \(x_t\) and average of \(\left\{x_i\right\}_{i=1}^t\) in primal space.
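
The update rule can be illustrated in plain Python. The following is a minimal scalar sketch that assumes the dual averaging recursion described in \([2]\); the class name DualAveragingSketch and its private attributes are hypothetical and are not taken from Pyro's source.

```python
class DualAveragingSketch:
    """Scalar dual averaging; an assumed form, not Pyro's implementation."""

    def __init__(self, prox_center=0.0, t0=10, kappa=0.75, gamma=0.05):
        self.prox_center = prox_center
        self.t0 = t0
        self.kappa = kappa
        self.gamma = gamma
        self.reset()

    def reset(self):
        self._t = 0
        self._g_avg = 0.0  # running average of statistics (dual space)
        self._x_t = 0.0    # latest primal iterate
        self._x_avg = 0.0  # weighted average of primal iterates

    def step(self, g):
        self._t += 1
        # t0 damps the influence of the first few statistics.
        eta = 1.0 / (self._t + self.t0)
        self._g_avg = (1.0 - eta) * self._g_avg + eta * g
        # Move away from prox_center against the averaged statistic.
        self._x_t = self.prox_center - (self._t ** 0.5) / self.gamma * self._g_avg
        # kappa sets how quickly early iterates are forgotten.
        weight = self._t ** (-self.kappa)
        self._x_avg = (1.0 - weight) * self._x_avg + weight * self._x_t

    def get_state(self):
        return self._x_t, self._x_avg


scheme = DualAveragingSketch()
scheme.step(1.0)
x_t, x_avg = scheme.get_state()
```

In this sketch a positive statistic g pushes the iterate below prox_center, and after the first step the average equals the iterate itself.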

velocity_verlet(z, r, potential_fn, kinetic_grad, step_size, num_steps=1, z_grads=None)[source]

Second order symplectic integrator that uses the velocity Verlet algorithm.

Parameters
  • z (dict) – dictionary of sample site names and their current values (type Tensor).

  • r (dict) – dictionary of sample site names and corresponding momenta (type Tensor).

  • potential_fn (callable) – function that returns potential energy given z for each sample site. The negative gradient of the function with respect to z determines the rate of change of the corresponding sites’ momenta r.

  • kinetic_grad (callable) – a function calculating gradient of kinetic energy w.r.t. momentum variable.

  • step_size (float) – step size for each time step iteration.

  • num_steps (int) – number of discrete time steps over which to integrate.

  • z_grads (torch.Tensor) – optional gradients of potential energy at current z.

Return tuple (z_next, r_next, z_grads, potential_energy)

next position and momenta, together with the potential energy and its gradient w.r.t. z_next.
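
The integration scheme can be sketched for a single scalar site. The helper velocity_verlet_1d below is hypothetical: it takes the gradient of the potential directly instead of a potential_fn, and it works on floats rather than dictionaries of tensors.

```python
def velocity_verlet_1d(z, r, potential_grad, kinetic_grad, step_size, num_steps=1):
    """Integrate one scalar position/momentum pair (illustrative helper)."""
    grad = potential_grad(z)
    for _ in range(num_steps):
        r = r - 0.5 * step_size * grad       # half step for momentum
        z = z + step_size * kinetic_grad(r)  # full step for position
        grad = potential_grad(z)
        r = r - 0.5 * step_size * grad       # second half step for momentum
    return z, r


def energy(z, r):
    # Harmonic oscillator: U(z) = z**2 / 2 and K(r) = r**2 / 2.
    return 0.5 * z * z + 0.5 * r * r


z0, r0 = 1.0, 0.0
z1, r1 = velocity_verlet_1d(z0, r0, lambda z: z, lambda r: r, 0.1, 100)
# Flipping the momentum and integrating again retraces the trajectory.
z2, r2 = velocity_verlet_1d(z1, -r1, lambda z: z, lambda r: r, 0.1, 100)
```

The total energy stays close to its initial value, and the scheme is time reversible, which are the two properties HMC relies on.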

potential_grad(potential_fn, z)[source]

Gradient of potential_fn w.r.t. parameters z.

Parameters
  • potential_fn – python callable that takes in a dictionary of parameters and returns the potential energy.

  • z (dict) – dictionary of parameter values keyed by site name.

Returns

tuple of (z_grads, potential_energy), where z_grads is a dictionary with the same keys as z containing gradients and potential_energy is a torch scalar.

register_exception_handler(name: str, handler: Callable[[Exception], bool], warn_on_overwrite: bool = True) None[source]

Register an exception handler for handling (primarily numerical) errors when evaluating the potential function.

Parameters
  • name – name of the handler (must be unique).

  • handler – A callable mapping an Exception to a boolean. Exceptions that evaluate to true in any of the handlers are handled in the computation of the potential energy.

  • warn_on_overwrite – If True, warns when overwriting a handler already registered under the provided name.

class WelfordCovariance(diagonal=True)[source]

Bases: object

Implements Welford’s online scheme for estimating (co)variance (see \([1]\)). Useful for adapting diagonal and dense mass structures for HMC.

References

[1] The Art of Computer Programming, Donald E. Knuth

reset()[source]
update(sample)[source]
get_covariance(regularize=True)[source]
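
Welford's online scheme can be sketched for scalar samples as follows. The class WelfordVarianceSketch is a hypothetical illustration; it tracks a single variance and omits the regularization that get_covariance() applies when regularize=True.

```python
class WelfordVarianceSketch:
    """Online mean and variance for scalar samples (illustrative only)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, sample):
        self.n += 1
        delta_pre = sample - self.mean
        self.mean += delta_pre / self.n
        delta_post = sample - self.mean
        self.m2 += delta_pre * delta_post

    def get_variance(self):
        # Unbiased estimate; requires at least two samples.
        return self.m2 / (self.n - 1)


welford = WelfordVarianceSketch()
for sample in [1.0, 2.0, 3.0, 4.0]:
    welford.update(sample)
```

Each update costs constant time and memory, so the estimate can be refreshed after every sample in a warmup phase.
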
class WelfordArrowheadCovariance(head_size=0)[source]

Bases: object

Like WelfordCovariance but generalized to the arrowhead structure.

reset()[source]
update(sample)[source]
get_covariance(regularize=True)[source]

Gets the covariance in arrowhead form: (top, bottom_diag) where top = cov[:head_size] and bottom_diag = cov.diag()[head_size:].

Newton Optimizers

newton_step(loss, x, trust_radius=None)[source]

Performs a Newton update step to minimize loss on a batch of variables, optionally constraining to a trust region [1].

This is especially useful because the final solution of Newton iteration is differentiable with respect to the inputs, even when all but the final x is detached, due to this method’s quadratic convergence [2]. loss must be twice-differentiable as a function of x. If loss is 2+d-times differentiable, then the return value of this function is d-times differentiable.

When loss is interpreted as a negative log probability density, then the return values mode,cov of this function can be used to construct a Laplace approximation MultivariateNormal(mode,cov).

Warning

Take care to detach the result of this function when used in an optimization loop. If you forget to detach the result of this function during optimization, then backprop will propagate through the entire iteration process, and worse will compute two extra derivatives for each step.

Example use inside a loop:

x = torch.zeros(1000, 2)  # arbitrary initial value
for step in range(100):
    x = x.detach()          # block gradients through previous steps
    x.requires_grad = True  # ensure loss is differentiable wrt x
    loss = my_loss_function(x)
    x = newton_step(loss, x, trust_radius=1.0)
# the final x is still differentiable

[1] Yuan, Ya-xiang. “A review of trust region algorithms for optimization.” ICIAM. Vol. 99. 2000. ftp://ftp.cc.ac.cn/pub/yyx/papers/p995.pdf

[2] Christianson, Bruce. “Reverse accumulation and attractive fixed points.” Optimization Methods and Software 3.4 (1994). http://uhra.herts.ac.uk/bitstream/handle/2299/4338/903839.pdf

Parameters
  • loss (torch.Tensor) – A scalar function of x to be minimized.

  • x (torch.Tensor) – A dependent variable of shape (N, D) where N is the batch size and D is a small number.

  • trust_radius (float) – An optional trust region trust_radius. The updated value mode of this function will be within trust_radius of the input x.

Returns

A pair (mode, cov) where mode is an updated tensor of the same shape as the original value x, and cov is an estimate of the covariance DxD matrix with cov.shape == x.shape[:-1] + (D,D).

Return type

tuple
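
The relationship between the Newton update, the trust region, and the returned covariance can be sketched for a single scalar variable. The helper newton_step_scalar below is hypothetical: it takes analytic gradient and Hessian callables instead of using autograd, assumes positive curvature, and enforces the trust region by simple clamping, which may differ from how Pyro constrains the step.

```python
def newton_step_scalar(grad, hess, x, trust_radius=None):
    """One Newton step for a scalar loss (illustrative helper)."""
    g = grad(x)
    h = hess(x)  # assumed positive in this sketch
    step = -g / h
    if trust_radius is not None:
        # Keep the update within trust_radius of the input x.
        step = max(-trust_radius, min(trust_radius, step))
    mode = x + step
    cov = 1.0 / h  # inverse Hessian, as used in a Laplace approximation
    return mode, cov


# Loss 1.5 * (x - 2) ** 2 has gradient 3 * (x - 2) and Hessian 3.
def grad(x):
    return 3.0 * (x - 2.0)


def hess(x):
    return 3.0


mode, cov = newton_step_scalar(grad, hess, 0.0)
clamped_mode, _ = newton_step_scalar(grad, hess, 0.0, trust_radius=1.0)
```

For a quadratic loss a single unconstrained step lands on the minimizer, while the trust region limits how far one step can move.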

newton_step_1d(loss, x, trust_radius=None)[source]

Performs a Newton update step to minimize loss on a batch of 1-dimensional variables, optionally regularizing to constrain to a trust region.

See newton_step() for details.

Parameters
  • loss (torch.Tensor) – A scalar function of x to be minimized.

  • x (torch.Tensor) – A dependent variable with rightmost size of 1.

  • trust_radius (float) – An optional trust region trust_radius. The updated value mode of this function will be within trust_radius of the input x.

Returns

A pair (mode, cov) where mode is an updated tensor of the same shape as the original value x, and cov is an estimate of the covariance 1x1 matrix with cov.shape == x.shape[:-1] + (1,1).

Return type

tuple

newton_step_2d(loss, x, trust_radius=None)[source]

Performs a Newton update step to minimize loss on a batch of 2-dimensional variables, optionally regularizing to constrain to a trust region.

See newton_step() for details.

Parameters
  • loss (torch.Tensor) – A scalar function of x to be minimized.

  • x (torch.Tensor) – A dependent variable with rightmost size of 2.

  • trust_radius (float) – An optional trust region trust_radius. The updated value mode of this function will be within trust_radius of the input x.

Returns

A pair (mode, cov) where mode is an updated tensor of the same shape as the original value x, and cov is an estimate of the covariance 2x2 matrix with cov.shape == x.shape[:-1] + (2,2).

Return type

tuple

newton_step_3d(loss, x, trust_radius=None)[source]

Performs a Newton update step to minimize loss on a batch of 3-dimensional variables, optionally regularizing to constrain to a trust region.

See newton_step() for details.

Parameters
  • loss (torch.Tensor) – A scalar function of x to be minimized.

  • x (torch.Tensor) – A dependent variable with rightmost size of 3.

  • trust_radius (float) – An optional trust region trust_radius. The updated value mode of this function will be within trust_radius of the input x.

Returns

A pair (mode, cov) where mode is an updated tensor of the same shape as the original value x, and cov is an estimate of the covariance 3x3 matrix with cov.shape == x.shape[:-1] + (3,3).

Return type

tuple

Special Functions

safe_log(x)[source]

Like torch.log() but avoids infinite gradients at log(0) by clamping them to at most 1 / finfo.eps.

log_beta(x, y, tol=0.0)[source]

Computes log Beta function.

When tol >= 0.02 this uses a shifted Stirling’s approximation to the log Beta function. The approximation adapts Stirling’s approximation of the log Gamma function:

lgamma(z) ≈ (z - 1/2) * log(z) - z + log(2 * pi) / 2

to approximate the log Beta function:

log_beta(x, y) ≈ ((x-1/2) * log(x) + (y-1/2) * log(y)
                  - (x+y-1/2) * log(x+y) + log(2*pi)/2)

The approximation additionally improves accuracy near zero by iteratively shifting the log Gamma approximation using the recursion:

lgamma(x) = lgamma(x + 1) - log(x)

If this recursion is applied n times, then absolute error is bounded by error < 0.082 / n < tol, thus we choose n based on the user provided tol.

Parameters
  • x (torch.Tensor) – A positive tensor.

  • y (torch.Tensor) – A positive tensor.

  • tol (float) – Bound on maximum absolute error. Defaults to 0.0, matching the signature above. For tol below 0.02, this function computes the log Beta function exactly rather than using the approximation.

Return type

torch.Tensor
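
The effect of shifting can be checked numerically in plain Python. The sketch below applies the recursion and Stirling's formula exactly as written above; the helper names are hypothetical, and the shift count is passed directly rather than derived from tol.

```python
import math


def stirling_lgamma(z):
    # lgamma(z) ~ (z - 1/2) * log(z) - z + log(2 * pi) / 2
    return (z - 0.5) * math.log(z) - z + 0.5 * math.log(2 * math.pi)


def shifted_lgamma(z, shift):
    # Apply lgamma(z) = lgamma(z + 1) - log(z) a total of `shift` times.
    correction = sum(math.log(z + i) for i in range(shift))
    return stirling_lgamma(z + shift) - correction


def log_beta_approx(x, y, shift):
    return (shifted_lgamma(x, shift) + shifted_lgamma(y, shift)
            - shifted_lgamma(x + y, shift))


def log_beta_exact(x, y):
    return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)


x, y = 0.5, 2.0
exact = log_beta_exact(x, y)
error_unshifted = abs(log_beta_approx(x, y, 0) - exact)
error_shifted = abs(log_beta_approx(x, y, 8) - exact)
```

With no shift the approximation is poor near zero, and shifting by a few steps reduces the error substantially.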

log_binomial(n, k, tol=0.0)[source]

Computes log binomial coefficient.

When tol >= 0.02 this uses a shifted Stirling’s approximation to the log Beta function via log_beta().

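Parameters
  • n (torch.Tensor) – A nonnegative integer tensor.

  • k (torch.Tensor) – An integer tensor ranging in [0, n].

  • tol (float) – Bound on maximum absolute error. Defaults to 0.0.
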
Return type

torch.Tensor
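
The link between the binomial coefficient and the Beta function can be sketched with the standard library. The helpers below are hypothetical scalar versions that use exact log Gamma values rather than the Stirling approximation.

```python
import math


def log_beta_exact(x, y):
    return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)


def log_binomial_sketch(n, k):
    # C(n, k) = 1 / ((n + 1) * B(k + 1, n - k + 1))
    return -math.log(n + 1) - log_beta_exact(k + 1, n - k + 1)


five_choose_two = math.exp(log_binomial_sketch(5, 2))
```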

log_I1(orders: int, value: torch.Tensor, terms=250)[source]

Computes the log modified Bessel functions of the first kind for orders 0 to orders:

\[\log(I_v(z)) = v \log(z/2) + \log\left(\sum_{k=0}^{\infty} \exp\left[2 k \log(z/2) - \sum_{kk=1}^{k} \log(kk) - \operatorname{lgamma}(v + k + 1)\right]\right)\]

Parameters
  • orders – orders of the log modified bessel function.

  • value – values to compute modified bessel function for

  • terms – truncation of summation

Returns

The log modified Bessel function values for orders 0 to orders.
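
The truncated series can be sketched for a single order and a scalar argument. The helper log_bessel_i is hypothetical and uses math.lgamma(k + 1) for the log factorial term; the inner sum is evaluated with a log-sum-exp for stability.

```python
import math


def log_bessel_i(order, z, terms=250):
    """log I_order(z) from the truncated power series (illustrative helper)."""
    log_half_z = math.log(z / 2.0)
    exponents = [
        2 * k * log_half_z - math.lgamma(k + 1) - math.lgamma(order + k + 1)
        for k in range(terms)
    ]
    top = max(exponents)
    log_sum = top + math.log(sum(math.exp(e - top) for e in exponents))
    return order * log_half_z + log_sum


i0_at_one = math.exp(log_bessel_i(0, 1.0))
i1_at_one = math.exp(log_bessel_i(1, 1.0))
```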

get_quad_rule(num_quad, prototype_tensor)[source]

Get quadrature points and corresponding log weights for a Gauss Hermite quadrature rule with the specified number of quadrature points.

Example usage:

quad_points, log_weights = get_quad_rule(32, prototype_tensor)
# transform to N(0, 4.0) Normal distribution
quad_points *= 4.0
# compute variance integral in log-space using logsumexp and exponentiate
variance = torch.logsumexp(quad_points.pow(2.0).log() + log_weights, axis=0).exp()
assert (variance - 16.0).abs().item() < 1.0e-6
Parameters
  • num_quad (int) – number of quadrature points.

  • prototype_tensor (torch.Tensor) – used to determine dtype and device of returned tensors.

Returns

tuple of torch.Tensors of the form (quad_points, log_weights)
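
To show what such a rule computes, the following sketch hand-codes a three-point Gauss-Hermite rule for a standard normal density and rescales it in the same way as the example above. The nodes and weights here are written out manually and use plain weights rather than log weights; they are not produced by get_quad_rule.

```python
import math

# Three-point rule for N(0, 1); exact for polynomials up to degree 5.
quad_points = [-math.sqrt(3.0), 0.0, math.sqrt(3.0)]
weights = [1.0 / 6.0, 2.0 / 3.0, 1.0 / 6.0]

# Rescale the nodes to a normal distribution with standard deviation 4.
scale = 4.0
scaled_points = [scale * p for p in quad_points]

variance = sum(w * p ** 2 for w, p in zip(weights, scaled_points))
fourth_moment = sum(w * p ** 4 for w, p in zip(weights, scaled_points))
```

The rule recovers the variance 16 and the fourth moment 3 * 4**4 = 768 of the rescaled distribution.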

sparse_multinomial_likelihood(total_count, nonzero_logits, nonzero_value)[source]

The following are equivalent:

# Version 1. dense
log_prob = Multinomial(logits=logits).log_prob(value).sum()

# Version 2. sparse
nnz = value.nonzero(as_tuple=True)
log_prob = sparse_multinomial_likelihood(
    value.sum(-1),
    (logits - logits.logsumexp(-1, keepdim=True))[nnz],
    value[nnz],
)
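
The equivalence holds because zero counts contribute nothing to the log probability. A minimal unbatched sketch in plain Python, with hypothetical helper names, makes this explicit:

```python
import math


def dense_multinomial_log_prob(log_probs, value):
    result = math.lgamma(sum(value) + 1)
    for lp, v in zip(log_probs, value):
        result += v * lp - math.lgamma(v + 1)
    return result


def sparse_multinomial_log_prob(total_count, nonzero_log_probs, nonzero_value):
    # Same formula, but only nonzero counts are visited.
    result = math.lgamma(total_count + 1)
    for lp, v in zip(nonzero_log_probs, nonzero_value):
        result += v * lp - math.lgamma(v + 1)
    return result


logits = [0.5, -1.0, 2.0, 0.0]
log_norm = math.log(sum(math.exp(l) for l in logits))
log_probs = [l - log_norm for l in logits]
value = [3, 0, 2, 0]

nnz = [i for i, v in enumerate(value) if v != 0]
dense = dense_multinomial_log_prob(log_probs, value)
sparse = sparse_multinomial_log_prob(
    sum(value), [log_probs[i] for i in nnz], [value[i] for i in nnz]
)
```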

Tensor Utilities

as_complex(x)[source]

Similar to torch.view_as_complex() but copies data in case strides are not multiples of two.

block_diag_embed(mat)[source]

Takes a tensor of shape (…, B, M, N) and returns a block diagonal tensor of shape (…, B x M, B x N).

Parameters

mat (torch.Tensor) – an input tensor with 3 or more dimensions

Returns torch.Tensor

a block diagonal tensor with dimension mat.dim() - 1

block_diagonal(mat, block_size)[source]

Takes a block diagonal tensor of shape (…, B x M, B x N) and returns a tensor of shape (…, B, M, N).

Parameters
  • mat (torch.Tensor) – an input tensor with 2 or more dimensions

  • block_size (int) – the number of blocks B.

Returns torch.Tensor

a tensor with dimension mat.dim() + 1
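
The two layouts can be illustrated with nested lists. The functions below are hypothetical unbatched sketches of block_diag_embed() and block_diagonal() for a list of B blocks of shape (M, N).

```python
def block_diag_embed_sketch(blocks):
    num_blocks = len(blocks)
    rows, cols = len(blocks[0]), len(blocks[0][0])
    out = [[0] * (num_blocks * cols) for _ in range(num_blocks * rows)]
    for b, block in enumerate(blocks):
        for i in range(rows):
            for j in range(cols):
                out[b * rows + i][b * cols + j] = block[i][j]
    return out


def block_diagonal_sketch(mat, num_blocks):
    rows = len(mat) // num_blocks
    cols = len(mat[0]) // num_blocks
    return [
        [[mat[b * rows + i][b * cols + j] for j in range(cols)] for i in range(rows)]
        for b in range(num_blocks)
    ]


blocks = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
embedded = block_diag_embed_sketch(blocks)
recovered = block_diagonal_sketch(embedded, 2)
```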

periodic_repeat(tensor, size, dim)[source]

Repeat a period-sized tensor up to given size. For example:

>>> x = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> periodic_repeat(x, 4, 0)
tensor([[1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6]])
>>> periodic_repeat(x, 4, 1)
tensor([[1, 2, 3, 1],
        [4, 5, 6, 4]])

This is useful for computing static seasonality in time series models.

Parameters
  • tensor (torch.Tensor) – A period-sized tensor to repeat.

  • size (int) – Desired size of the result along dimension dim.

  • dim (int) – The tensor dimension along which to repeat.

periodic_cumsum(tensor, period, dim)[source]

Compute periodic cumsum along a given dimension. For example if dim=0:

for t in range(period):
    assert result[t] == tensor[t]
for t in range(period, len(tensor)):
    assert result[t] == tensor[t] + result[t - period]

This is useful for computing drifting seasonality in time series models.

Parameters
  • tensor (torch.Tensor) – A tensor of differences.

  • period (int) – The period of repetition.

  • dim (int) – The tensor dimension along which to accumulate.
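
Both periodic helpers can be sketched on plain Python lists. The one-dimensional functions below are hypothetical illustrations of the semantics described above, not Pyro's vectorized implementations.

```python
def periodic_repeat_1d(values, size):
    period = len(values)
    return [values[t % period] for t in range(size)]


def periodic_cumsum_1d(values, period):
    result = list(values)
    for t in range(period, len(values)):
        # Each entry accumulates the entry one period earlier.
        result[t] = values[t] + result[t - period]
    return result


repeated = periodic_repeat_1d([1, 2, 3], 7)
accumulated = periodic_cumsum_1d([1, 2, 3, 4, 5], 2)
```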

periodic_features(duration, max_period=None, min_period=None, **options)[source]

Create periodic (sin,cos) features from max_period down to min_period.

This is useful in time series models where long uneven seasonality can be treated via regression. When only max_period is specified this generates periodic features at all length scales. When min_period is also specified this generates periodic features at large length scales, but omits high frequency features. This is useful when combining regression for long seasonality with other techniques like periodic_repeat() and periodic_cumsum() for short time scales. For example, to regress yearly seasonality down to the scale of one week one could set max_period=365.25 and min_period=7.

Parameters
  • duration (int) – Number of discrete time steps.

  • max_period (float) – Optional max period, defaults to duration.

  • min_period (float) – Optional min period (exclusive), defaults to 2 = Nyquist cutoff.

  • **options – Tensor construction options, e.g. dtype and device.

Returns

A (duration, 2 * ceil(max_period / min_period) - 2)-shaped tensor of features normalized to lie in [-1,1].

Return type

Tensor

next_fast_len(size)[source]

Returns the smallest number n >= size whose prime factors are all 2, 3, or 5. These sizes are efficient for fast Fourier transforms. Equivalent to scipy.fftpack.next_fast_len().

Parameters

size (int) – A positive number.

Returns

A possibly larger number.

Return type

int
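
A straightforward way to find such a size is to test successive integers. The function below is a hypothetical pure-Python sketch of that search and makes no claim about how Pyro or SciPy implement it.

```python
def next_fast_len_sketch(size):
    n = max(int(size), 1)
    while True:
        remainder = n
        # Strip all factors of 2, 3, and 5.
        for prime in (2, 3, 5):
            while remainder % prime == 0:
                remainder //= prime
        if remainder == 1:
            return n
        n += 1


fast_sizes = [next_fast_len_sketch(s) for s in (1, 7, 11, 13, 16, 17)]
```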

convolve(signal, kernel, mode='full')[source]

Computes the 1-d convolution of signal by kernel using FFTs. The two arguments should have the same rightmost dim, but may otherwise be arbitrarily broadcastable.

Parameters
  • signal (torch.Tensor) – A signal to convolve.

  • kernel (torch.Tensor) – A convolution kernel.

  • mode (str) – One of: ‘full’, ‘valid’, ‘same’.

Returns

A tensor with broadcasted shape. Letting m = signal.size(-1) and n = kernel.size(-1), the rightmost size of the result will be: m + n - 1 if mode is ‘full’; max(m, n) - min(m, n) + 1 if mode is ‘valid’; or max(m, n) if mode is ‘same’.

Return type

torch.Tensor
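
The output sizes for the three modes can be checked with a direct convolution. The sketch below deliberately uses a quadratic-time loop instead of FFTs, and the centering used for the ‘same’ mode is an assumption of this example.

```python
def convolve_direct(signal, kernel, mode="full"):
    m, n = len(signal), len(kernel)
    full = [0.0] * (m + n - 1)
    for i, s in enumerate(signal):
        for j, k in enumerate(kernel):
            full[i + j] += s * k
    if mode == "full":
        return full
    if mode == "valid":
        length = max(m, n) - min(m, n) + 1
        start = min(m, n) - 1
    elif mode == "same":
        length = max(m, n)
        start = (len(full) - length) // 2  # assumed centering
    else:
        raise ValueError("mode must be 'full', 'valid', or 'same'")
    return full[start:start + length]


signal, kernel = [1.0, 2.0, 3.0], [1.0, 1.0]
full = convolve_direct(signal, kernel, "full")
valid = convolve_direct(signal, kernel, "valid")
same = convolve_direct(signal, kernel, "same")
```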

repeated_matmul(M, n)[source]

Takes a batch of matrices M as input and returns the stacked result of doing the n-many matrix multiplications \(M\), \(M^2\), …, \(M^n\). Parallel cost is logarithmic in n.

Parameters
  • M (torch.Tensor) – A batch of square tensors of shape (…, N, N).

  • n (int) – The order of the largest product \(M^n\)

Returns torch.Tensor

A batch of square tensors of shape (n, …, N, N)
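
The logarithmic parallel cost comes from doubling: each round multiplies all powers computed so far by the highest one. The following sketch uses nested lists and a hypothetical matmul_lists helper; in this sequential form the multiplications inside a round are independent, which is what a batched implementation could exploit.

```python
def matmul_lists(a, b):
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def repeated_matmul_sketch(M, n):
    powers = [M]
    while len(powers) < n:
        top = powers[-1]
        # One round doubles the number of known powers.
        powers = powers + [matmul_lists(top, p) for p in powers]
    return powers[:n]


M = [[1, 1], [0, 1]]
powers = repeated_matmul_sketch(M, 5)
```

For this matrix the k-th power is [[1, k], [0, 1]], which makes the result easy to check by eye.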

dct(x, dim=-1)[source]

Discrete cosine transform of type II, scaled to be orthonormal.

This is the inverse of idct(), and is equivalent to scipy.fftpack.dct() with norm="ortho".

Parameters
  • x (Tensor) – The input signal.

  • dim (int) – Dimension along which to compute DCT.

Return type

Tensor

idct(x, dim=-1)[source]

Inverse discrete cosine transform of type II, scaled to be orthonormal.

This is the inverse of dct(), and is equivalent to scipy.fftpack.idct() with norm="ortho".

Parameters
  • x (Tensor) – The input signal.

  • dim (int) – Dimension along which to compute DCT.

Return type

Tensor
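
The orthonormal scaling can be made concrete with a direct, quadratic-time sketch of the type II transform and its inverse. The helper names are hypothetical and the code operates on plain lists rather than tensors.

```python
import math


def dct_ortho(x):
    n = len(x)
    out = []
    for k in range(n):
        scale = math.sqrt((1.0 if k == 0 else 2.0) / n)
        total = sum(x[t] * math.cos(math.pi * (t + 0.5) * k / n) for t in range(n))
        out.append(scale * total)
    return out


def idct_ortho(coeffs):
    n = len(coeffs)
    out = []
    for t in range(n):
        total = 0.0
        for k in range(n):
            scale = math.sqrt((1.0 if k == 0 else 2.0) / n)
            total += scale * coeffs[k] * math.cos(math.pi * (t + 0.5) * k / n)
        out.append(total)
    return out


signal = [1.0, 2.0, 0.5, -1.0]
coeffs = dct_ortho(signal)
round_trip = idct_ortho(coeffs)
constant_coeffs = dct_ortho([1.0, 1.0, 1.0, 1.0])
```

Because the transform is orthonormal, it preserves the sum of squares of the signal, and a constant signal maps to a single nonzero coefficient.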

haar_transform(x)[source]

Discrete Haar transform.

Performs a Haar transform along the final dimension. This is the inverse of inverse_haar_transform().

Parameters

x (Tensor) – The input signal.

Return type

Tensor

inverse_haar_transform(x)[source]

Performs an inverse Haar transform along the final dimension. This is the inverse of haar_transform().

Parameters

x (Tensor) – The input signal.

Return type

Tensor

safe_cholesky(x)[source]
cholesky_solve(x, y)[source]
matmul(x, y)[source]
matvecmul(x, y)[source]
triangular_solve(x, y, upper=False, transpose=False)[source]
precision_to_scale_tril(P)[source]
safe_normalize(x, *, p=2)[source]

Safely project a vector onto the sphere wrt the p-norm. This avoids the singularity at zero by mapping zero to the vector [1, 0, 0, ..., 0].

Parameters
  • x (torch.Tensor) – A vector

  • p (float) – The norm exponent, defaults to 2 i.e. the Euclidean norm.

Returns

A normalized version x / ||x||_p.

Return type

Tensor
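
The zero-handling convention can be sketched on a plain list. The function below is a hypothetical scalar-list version of the behavior described above.

```python
def safe_normalize_sketch(x, p=2):
    norm = sum(abs(v) ** p for v in x) ** (1.0 / p)
    if norm == 0.0:
        # Map the zero vector to [1, 0, ..., 0] instead of dividing by zero.
        return [1.0] + [0.0] * (len(x) - 1)
    return [v / norm for v in x]


unit = safe_normalize_sketch([3.0, 4.0])
from_zero = safe_normalize_sketch([0.0, 0.0, 0.0])
```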

broadcast_tensors_without_dim(tensors, dim)[source]

Broadcast tensors to the same shape without changing the size of dimension dim of each tensor.

The broadcasting is performed in the same way as done in torch.broadcast_tensors(), while leaving the size of dimension dim of each tensor unchanged.

The returned tensors can be concatenated along the dimension dim.

Parameters
  • tensors (list) – List of torch.Tensor objects.

  • dim (int) – Dimension to leave out of broadcasting.

Returns

List of torch.Tensor objects.

Tensor Indexing

index(tensor, args)[source]

Indexing with nested tuples.

See also the convenience wrapper Index.

This is useful for writing indexing code that is compatible with multiple interpretations, e.g. scalar evaluation, vectorized evaluation, or reshaping.

For example suppose x is a parameter with x.dim() == 2 and we wish to generalize the expression x[..., t] where t can be any of:

  • a scalar t=1 as in x[..., 1];

  • a slice t=slice(None) equivalent to x[..., :]; or

  • a reshaping operation t=(Ellipsis, None) equivalent to x.unsqueeze(-1).

While naive indexing would work for the first two, the third example would result in a nested tuple (Ellipsis, (Ellipsis, None)). This helper flattens that nested tuple and combines consecutive Ellipsis.

Parameters
  • tensor (torch.Tensor) – A tensor to be indexed.

  • args (tuple) – An index, as args to __getitem__.

Returns

A flattened interpretation of tensor[args].

Return type

torch.Tensor
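
The flattening step can be sketched without any tensors. The function flatten_index below is a hypothetical helper that only normalizes the index tuple; applying the result to a tensor is left out.

```python
def flatten_index(args):
    """Flatten nested index tuples and merge consecutive Ellipsis."""
    if not isinstance(args, tuple):
        args = (args,)
    flat = []
    for arg in args:
        if isinstance(arg, tuple):
            flat.extend(flatten_index(arg))
        else:
            flat.append(arg)
    merged = []
    for arg in flat:
        if arg is Ellipsis and merged and merged[-1] is Ellipsis:
            continue  # drop the repeated Ellipsis
        merged.append(arg)
    return tuple(merged)


scalar_case = flatten_index((Ellipsis, 1))
slice_case = flatten_index((Ellipsis, slice(None)))
reshape_case = flatten_index((Ellipsis, (Ellipsis, None)))
```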

class Index(tensor)[source]

Bases: object

Convenience wrapper around index().

The following are equivalent:

Index(x)[..., i, j, :]
index(x, (Ellipsis, i, j, slice(None)))
Parameters

tensor (torch.Tensor) – A tensor to be indexed.

Returns

An object with a special __getitem__() method.

vindex(tensor, args)[source]

Vectorized advanced indexing with broadcasting semantics.

See also the convenience wrapper Vindex.

This is useful for writing indexing code that is compatible with batching and enumeration, especially for selecting mixture components with discrete random variables.

For example suppose x is a parameter with x.dim() == 3 and we wish to generalize the expression x[i, :, j] from integer i,j to tensors i,j with batch dims and enum dims (but no event dims). Then we can write the generalized version using Vindex:

xij = Vindex(x)[i, :, j]

batch_shape = broadcast_shape(i.shape, j.shape)
event_shape = (x.size(1),)
assert xij.shape == batch_shape + event_shape

To handle the case when x may also contain batch dimensions (e.g. if x was sampled in a plated context as when using vectorized particles), vindex() uses the special convention that Ellipsis denotes batch dimensions (hence ... can appear only on the left, never in the middle or on the right). Suppose x has event dim 3. Then we can write:

old_batch_shape = x.shape[:-3]
old_event_shape = x.shape[-3:]

xij = Vindex(x)[..., i, :, j]   # The ... denotes unknown batch shape.

new_batch_shape = broadcast_shape(old_batch_shape, i.shape, j.shape)
new_event_shape = (x.size(-2),)
assert xij.shape == new_batch_shape + new_event_shape

Note that this special handling of Ellipsis differs from the NEP [1].

Formally, this function assumes:

  1. Each arg is either Ellipsis, slice(None), an integer, or a batched torch.LongTensor (i.e. with empty event shape). This function does not support nontrivial slices or torch.BoolTensor masks. Ellipsis can only appear on the left as args[0].

  2. If args[0] is not Ellipsis then tensor is not batched, and its event dim is equal to len(args).

  3. If args[0] is Ellipsis then tensor is batched and its event dim is equal to len(args[1:]). Dims of tensor to the left of the event dims are considered batch dims and will be broadcasted with dims of tensor args.

Note that if none of the args is a tensor with .dim() > 0, then this function behaves like standard indexing:

if not any(isinstance(a, torch.Tensor) and a.dim() for a in args):
    assert Vindex(x)[args] == x[args]

References

[1] https://www.numpy.org/neps/nep-0021-advanced-indexing.html

introduces vindex as a helper for vectorized indexing. The Pyro implementation is similar to the proposed notation x.vindex[] except for slightly different handling of Ellipsis.

Parameters
  • tensor (torch.Tensor) – A tensor to be indexed.

  • args (tuple) – An index, as args to __getitem__.

Returns

A nonstandard interpretation of tensor[args].

Return type

torch.Tensor

class Vindex(tensor)[source]

Bases: object

Convenience wrapper around vindex().

The following are equivalent:

Vindex(x)[..., i, j, :]
vindex(x, (Ellipsis, i, j, slice(None)))
Parameters

tensor (torch.Tensor) – A tensor to be indexed.

Returns

An object with a special __getitem__() method.

Tensor Contraction

contract(equation, *operands, **kwargs)[source]

Wrapper around opt_einsum.contract() that optionally uses Pyro’s cheap optimizer and optionally caches contraction paths.

Parameters

cache_path (bool) – whether to cache the contraction path. Defaults to True.

contract_expression(equation, *shapes, **kwargs)[source]

Wrapper around opt_einsum.contract_expression() that optionally uses Pyro’s cheap optimizer and optionally caches contraction paths.

Parameters

cache_path (bool) – whether to cache the contraction path. Defaults to True.

einsum(equation, *operands, **kwargs)[source]

Generalized plated sum-product algorithm via tensor variable elimination.

This generalizes contract() in two ways:

  1. Multiple outputs are allowed, and intermediate results can be shared.

  2. Inputs and outputs can be plated along symbols given in plates; reductions along plates are product reductions.

The best way to understand this function is to try the examples below, which show how einsum() calls can be implemented as multiple calls to contract() (which is generally more expensive).

To illustrate multiple outputs, note that the following are equivalent:

z1, z2, z3 = einsum('ab,bc->a,b,c', x, y)  # multiple outputs

z1 = contract('ab,bc->a', x, y)
z2 = contract('ab,bc->b', x, y)
z3 = contract('ab,bc->c', x, y)

To illustrate plated inputs, note that the following are equivalent:

assert len(x) == 3 and len(y) == 3
z = einsum('ab,ai,bi->b', w, x, y, plates='i')

z = contract('ab,a,a,a,b,b,b->b', w, *x, *y)

When a sum dimension a always appears with a plate dimension i, then a corresponds to a distinct symbol for each slice of i. Thus the following are equivalent:

assert len(x) == 3 and len(y) == 3
z = einsum('ai,ai->', x, y, plates='i')

z = contract('a,b,c,a,b,c->', *x, *y)

When such a sum dimension appears in the output, it must be accompanied by all of its plate dimensions, e.g. the following are equivalent:

assert len(x) == 3 and len(y) == 3
z = einsum('abi,abi->bi', x, y, plates='i')

z0 = contract('ab,ac,ad,ab,ac,ad->b', *x, *y)
z1 = contract('ab,ac,ad,ab,ac,ad->c', *x, *y)
z2 = contract('ab,ac,ad,ab,ac,ad->d', *x, *y)
z = torch.stack([z0, z1, z2])

Note that each plate slice through the output is multilinear in all plate slices through all inputs, thus e.g. batch matrix multiply would be implemented without plates, so the following are all equivalent:

xy = einsum('abc,acd->abd', x, y, plates='')
xy = torch.stack([xa.mm(ya) for xa, ya in zip(x, y)])
xy = torch.bmm(x, y)

Among all valid equations, some computations are polynomial in the sizes of the input tensors and other computations are exponential in the sizes of the input tensors. This function raises NotImplementedError whenever the computation is exponential.

Parameters
  • equation (str) – An einsum equation, optionally with multiple outputs.

  • operands (torch.Tensor) – A collection of tensors.

  • plates (str) – An optional string of plate symbols.

  • backend (str) – An optional einsum backend, defaults to ‘torch’.

  • cache (dict) – An optional shared_intermediates() cache.

  • modulo_total (bool) – Optionally allow einsum to arbitrarily scale each result plate, which can significantly reduce computation. This is safe to set whenever each result plate denotes a nonnormalized probability distribution whose total is not of interest.

Returns

a tuple of tensors of requested shape, one entry per output.

Return type

tuple

Raises
  • ValueError – if tensor sizes mismatch or an output requests a plated dim without that dim’s plates.

  • NotImplementedError – if contraction would have cost exponential in the size of any input tensor.
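
The statement that reductions along plates are product reductions can be checked by hand for a small case. The sketch below mirrors the pattern 'ai,ai->' with plate i in plain Python: the lists x_slices and y_slices hold one slice per plate index, and all names and values are illustrative.

```python
import itertools

# One slice per plate index i = 0, 1, 2; each slice is indexed by a.
x_slices = [[0.2, 0.8], [0.5, 0.5], [0.9, 0.1]]
y_slices = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

# Plated form: sum over a within each slice, then multiply across slices.
plated = 1.0
for x_i, y_i in zip(x_slices, y_slices):
    plated *= sum(xa * ya for xa, ya in zip(x_i, y_i))

# Unrolled form: a distinct summed symbol (a, b, c) for each plate slice.
unrolled = 0.0
for a, b, c in itertools.product(range(2), repeat=3):
    unrolled += (x_slices[0][a] * x_slices[1][b] * x_slices[2][c]
                 * y_slices[0][a] * y_slices[1][b] * y_slices[2][c])
```

The plated form needs a number of operations linear in the plate size, whereas the unrolled sum grows exponentially with it.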

ubersum(equation, *operands, **kwargs)[source]

Deprecated, use einsum() instead.

Gaussian Contraction

class Gaussian(log_normalizer: torch.Tensor, info_vec: torch.Tensor, precision: torch.Tensor)[source]

Bases: object

Non-normalized Gaussian distribution.

This represents an arbitrary semidefinite quadratic function, which can be interpreted as a rank-deficient scaled Gaussian distribution. The precision matrix may have zero eigenvalues, thus it may be impossible to work directly with the covariance matrix.

Parameters
  • log_normalizer (torch.Tensor) – a normalization constant, which is mainly used to keep track of normalization terms during contractions.

  • info_vec (torch.Tensor) – information vector, which is a scaled version of the mean info_vec = precision @ mean. We use this representation to make Gaussian contraction fast and stable.

  • precision (torch.Tensor) – precision matrix of this Gaussian.

dim()[source]
property batch_shape
expand(batch_shape) pyro.ops.gaussian.Gaussian[source]
reshape(batch_shape) pyro.ops.gaussian.Gaussian[source]
__getitem__(index) pyro.ops.gaussian.Gaussian[source]

Index into the batch_shape of a Gaussian.

static cat(parts, dim=0) pyro.ops.gaussian.Gaussian[source]

Concatenate a list of Gaussians along a given batch dimension.

event_pad(left=0, right=0) pyro.ops.gaussian.Gaussian[source]

Pad along event dimension.

event_permute(perm) pyro.ops.gaussian.Gaussian[source]

Permute along event dimension.

__add__(other: Union[pyro.ops.gaussian.Gaussian, int, float, torch.Tensor]) pyro.ops.gaussian.Gaussian[source]

Adds two Gaussians in log-density space.

log_density(value: torch.Tensor) torch.Tensor[source]

Evaluate the log density of this Gaussian at a point value:

-0.5 * value.T @ precision @ value + value.T @ info_vec + log_normalizer

This is mainly used for testing.

rsample(sample_shape=torch.Size([]), noise: Optional[torch.Tensor] = None) torch.Tensor[source]

Reparameterized sampler.

condition(value: torch.Tensor) pyro.ops.gaussian.Gaussian[source]

Condition this Gaussian on a trailing subset of its state. This should satisfy:

g.condition(y).dim() == g.dim() - y.size(-1)

Note that since this is a non-normalized Gaussian, we include the density of y in the result. Thus condition() is similar to a functools.partial binding of arguments:

left = x[..., :n]
right = x[..., n:]
g.log_density(x) == g.condition(right).log_density(left)
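
In information form, conditioning amounts to folding the observed block into the information vector and the log normalizer. The sketch below works this out for a two-dimensional Gaussian conditioned on its second coordinate; the function names and numbers are illustrative and the code does not use the Gaussian class.

```python
def log_density_2d(x1, x2, log_normalizer, info_vec, precision):
    quad = (precision[0][0] * x1 * x1 + 2 * precision[0][1] * x1 * x2
            + precision[1][1] * x2 * x2)
    return -0.5 * quad + info_vec[0] * x1 + info_vec[1] * x2 + log_normalizer


def condition_right(y, log_normalizer, info_vec, precision):
    """Return (log_normalizer, info, precision) of the Gaussian over x1."""
    new_precision = precision[0][0]
    new_info = info_vec[0] - precision[0][1] * y
    # The density of y is absorbed into the log normalizer.
    new_log_normalizer = (log_normalizer + info_vec[1] * y
                          - 0.5 * precision[1][1] * y * y)
    return new_log_normalizer, new_info, new_precision


def log_density_1d(x, log_normalizer, info, precision):
    return -0.5 * precision * x * x + info * x + log_normalizer


precision = [[2.0, 0.5], [0.5, 1.0]]
info_vec = [1.0, -0.5]
joint = log_density_2d(0.3, -1.2, 0.1, info_vec, precision)
conditioned = log_density_1d(0.3, *condition_right(-1.2, 0.1, info_vec, precision))
```
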
left_condition(value: torch.Tensor) pyro.ops.gaussian.Gaussian[source]

Condition this Gaussian on a leading subset of its state. This should satisfy:

g.left_condition(y).dim() == g.dim() - y.size(-1)

Note that since this is a non-normalized Gaussian, we include the density of y in the result. Thus left_condition() is similar to a functools.partial binding of arguments:

left = x[..., :n]
right = x[..., n:]
g.log_density(x) == g.left_condition(left).log_density(right)
marginalize(left=0, right=0) pyro.ops.gaussian.Gaussian[source]

Marginalizing out variables on either side of the event dimension:

g.marginalize(left=n).event_logsumexp() == g.event_logsumexp()
g.marginalize(right=n).event_logsumexp() == g.event_logsumexp()

and for data x:

g.condition(x).event_logsumexp() == g.marginalize(left=g.dim() - x.size(-1)).log_density(x)

event_logsumexp() torch.Tensor[source]

Integrates out all latent state (i.e. operating on event dimensions).
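
For a one-dimensional Gaussian in information form, this integral has a closed form that can be compared against numerical quadrature. The sketch below uses illustrative values and a midpoint rule; it does not call the Gaussian class.

```python
import math

log_normalizer, info_vec, precision = 0.3, 1.5, 2.0


def log_density(x):
    return -0.5 * precision * x * x + info_vec * x + log_normalizer


# Closed form of log(integral(exp(log_density(x)), x)) in one dimension.
closed_form = (log_normalizer + 0.5 * math.log(2 * math.pi / precision)
               + 0.5 * info_vec ** 2 / precision)

# Midpoint rule on [-20, 20] as an independent check.
step = 0.001
grid = [-20.0 + step * (i + 0.5) for i in range(40000)]
numeric = math.log(sum(math.exp(log_density(x)) for x in grid) * step)
```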

class AffineNormal(matrix, loc, scale)[source]

Bases: object

Represents a conditional diagonal normal distribution over a random variable Y whose mean is an affine function of a random variable X. The likelihood of X is thus:

AffineNormal(matrix, loc, scale).condition(y).log_density(x)

which is equivalent to:

Normal(x @ matrix + loc, scale).to_event(1).log_prob(y)
Parameters
  • matrix (torch.Tensor) – A transformation from X to Y. Should have rightmost shape (x_dim, y_dim).

  • loc (torch.Tensor) – A constant offset for Y’s mean. Should have rightmost shape (y_dim,).

  • scale (torch.Tensor) – Standard deviation for Y. Should have rightmost shape (y_dim,).

property batch_shape
condition(value)[source]
left_condition(value)[source]

If value.size(-1) == x_dim, this returns a Normal distribution with event_dim=1. After applying this method, the cost to draw a sample is O(y_dim) instead of O(y_dim ** 3).

rsample(sample_shape=torch.Size([]), noise: Optional[torch.Tensor] = None) torch.Tensor[source]

Reparameterized sampler.

to_gaussian()[source]
expand(batch_shape)[source]
reshape(batch_shape)[source]
__getitem__(index)[source]
event_permute(perm)[source]
__add__(other)[source]
marginalize(left=0, right=0)[source]
mvn_to_gaussian(mvn)[source]

Convert a MultivariateNormal distribution to a Gaussian.

Parameters

mvn (MultivariateNormal) – A multivariate normal distribution.

Returns

An equivalent Gaussian object.

Return type

Gaussian

matrix_and_gaussian_to_gaussian(matrix: torch.Tensor, y_gaussian: pyro.ops.gaussian.Gaussian) pyro.ops.gaussian.Gaussian[source]

Constructs a conditional Gaussian for p(y|x) where y - x @ matrix ~ y_gaussian.

Parameters
  • matrix (torch.Tensor) – A right-acting transformation matrix.

  • y_gaussian (Gaussian) – A distribution over noise of y - x@matrix.

Return type

Gaussian

matrix_and_mvn_to_gaussian(matrix, mvn)[source]

Convert a noisy affine function to a Gaussian. The noisy affine function is defined as:

y = x @ matrix + mvn.sample()
Parameters
  • matrix (Tensor) – A matrix with rightmost shape (x_dim, y_dim).

  • mvn (MultivariateNormal) – A multivariate normal distribution.

Returns

A Gaussian with broadcasted batch shape and .dim() == x_dim + y_dim.

Return type

Gaussian

gaussian_tensordot(x: pyro.ops.gaussian.Gaussian, y: pyro.ops.gaussian.Gaussian, dims: int = 0) pyro.ops.gaussian.Gaussian[source]

Computes the integral over two gaussians:

(x @ y)(a,c) = log(integral(exp(x(a,b) + y(b,c)), b)),

where x is a gaussian over variables (a,b), y is a gaussian over variables (b,c), (a,b,c) can each be sets of zero or more variables, and dims is the size of b.

Parameters
  • x – a Gaussian instance

  • y – a Gaussian instance

  • dims – number of variables to contract

sequential_gaussian_tensordot(gaussian: pyro.ops.gaussian.Gaussian) pyro.ops.gaussian.Gaussian[source]

Integrates a Gaussian x whose rightmost batch dimension is time, computing:

x[..., 0] @ x[..., 1] @ ... @ x[..., T-1]
Parameters

gaussian (Gaussian) – A batched Gaussian whose rightmost dimension is time.

Returns

A Markov product of the Gaussian along its time dimension.

Return type

Gaussian

sequential_gaussian_filter_sample(init: pyro.ops.gaussian.Gaussian, trans: pyro.ops.gaussian.Gaussian, sample_shape: Tuple[int, ...] = (), noise: Optional[torch.Tensor] = None) torch.Tensor[source]

Draws a reparameterized sample from a Markov product of Gaussians via parallel-scan forward-filter backward-sample.

Parameters
  • init (Gaussian) – A Gaussian representing an initial state.

  • trans (Gaussian) – A Gaussian representing a series of state transitions, with time as the rightmost batch dimension. This must have twice the event dim as init: trans.dim() == 2 * init.dim().

  • sample_shape (tuple) – An optional extra shape of samples to draw.

  • noise (torch.Tensor) – An optional standard white noise tensor of shape sample_shape + batch_shape + (duration, state_dim), where duration = 1 + trans.batch_shape[-1] is the number of time points to be sampled, and state_dim = init.dim() is the state dimension. This is useful for computing the mean (pass zeros), varying temperature (pass scaled noise), and antithetic sampling (pass cat([z,-z])).

Returns

A reparameterized sample of shape sample_shape + batch_shape + (duration, state_dim).

Return type

torch.Tensor

Statistical Utilities

gelman_rubin(input, chain_dim=0, sample_dim=1)[source]

Computes R-hat over chains of samples. It is required that input.size(sample_dim) >= 2 and input.size(chain_dim) >= 2.

Parameters
  • input (torch.Tensor) – the input tensor.

  • chain_dim (int) – the chain dimension.

  • sample_dim (int) – the sample dimension.

Returns torch.Tensor

R-hat of input.
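
A minimal sketch of the R-hat computation for scalar samples, assuming the usual definition in terms of within-chain and between-chain variance. The function name `rhat` and the list-of-lists layout are illustrative assumptions, not Pyro's tensor-based implementation.

```python
import math
from statistics import fmean, variance

def rhat(chains):
    """R-hat for a list of equal-length chains of scalar samples (illustrative)."""
    n = len(chains[0])  # number of samples per chain
    # Within-chain variance: average of the per-chain sample variances.
    var_within = fmean([variance(chain) for chain in chains])
    # Between-chain term: sample variance of the per-chain means.
    var_between = variance([fmean(chain) for chain in chains])
    # Pooled estimate of the posterior variance.
    var_estimator = (n - 1) / n * var_within + var_between
    return math.sqrt(var_estimator / var_within)

chains = [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]]
value = rhat(chains)
```

Because `statistics.variance` needs at least two data points, the sketch fails for fewer than two chains or fewer than two samples per chain, which mirrors the size requirement stated above.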

split_gelman_rubin(input, chain_dim=0, sample_dim=1)[source]

Computes split R-hat over chains of samples. It is required that input.size(sample_dim) >= 4.

Parameters
  • input (torch.Tensor) – the input tensor.

  • chain_dim (int) – the chain dimension.

  • sample_dim (int) – the sample dimension.

Returns torch.Tensor

split R-hat of input.

autocorrelation(input, dim=0)[source]

Computes the autocorrelation of samples at dimension dim.

Reference: https://en.wikipedia.org/wiki/Autocorrelation#Efficient_computation

Parameters
  • input (torch.Tensor) – the input tensor.

  • dim (int) – the dimension to calculate autocorrelation.

Returns torch.Tensor

autocorrelation of input.

autocovariance(input, dim=0)[source]

Computes the autocovariance of samples at dimension dim.

Parameters
  • input (torch.Tensor) – the input tensor.

  • dim (int) – the dimension to calculate autocovariance.

Returns torch.Tensor

autocovariance of input.
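
To make the relation between the two quantities concrete, here is a direct quadratic-time sketch for a single scalar series (the efficient computation referenced above uses an FFT instead). It assumes each lag is averaged over its number of overlapping pairs; that normalization and the function names are assumptions for illustration.

```python
def autocovariance_direct(x):
    """Autocovariance of a scalar series at lags 0..n-1 (illustrative)."""
    n = len(x)
    mean = sum(x) / n
    centered = [v - mean for v in x]
    # Assumed convention: average over the n - lag overlapping pairs.
    return [
        sum(centered[t] * centered[t + lag] for t in range(n - lag)) / (n - lag)
        for lag in range(n)
    ]

def autocorrelation_direct(x):
    """Autocovariance rescaled so that lag 0 equals one."""
    autocov = autocovariance_direct(x)
    return [c / autocov[0] for c in autocov]

series = [1.0, 2.0, 3.0, 4.0]
autocov = autocovariance_direct(series)
autocorr = autocorrelation_direct(series)
```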

effective_sample_size(input, chain_dim=0, sample_dim=1)[source]

Computes effective sample size of input.

Reference:

[1] Introduction to Markov Chain Monte Carlo, Charles J. Geyer

[2] Stan Reference Manual version 2.18, Stan Development Team

Parameters
  • input (torch.Tensor) – the input tensor.

  • chain_dim (int) – the chain dimension.

  • sample_dim (int) – the sample dimension.

Returns torch.Tensor

effective sample size of input.

resample(input, num_samples, dim=0, replacement=False)[source]

Draws num_samples samples from input at dimension dim.

Parameters
  • input (torch.Tensor) – the input tensor.

  • num_samples (int) – the number of samples to draw from input.

  • dim (int) – dimension to draw from input.

  • replacement (bool) – whether to draw samples with replacement.

Returns torch.Tensor

samples drawn randomly from input.
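
The `replacement` flag in the signature distinguishes two sampling modes. The sketch below mimics them on a plain Python list with the standard `random` module; `resample_list` and its `rng` argument are hypothetical names used only for this illustration.

```python
import random

def resample_list(samples, num_samples, replacement=False, rng=random):
    """Draw num_samples items from a list, with or without replacement."""
    if replacement:
        # Each draw is independent, so items may repeat.
        return rng.choices(samples, k=num_samples)
    # Without replacement, num_samples cannot exceed len(samples).
    return rng.sample(samples, num_samples)

rng = random.Random(0)
population = list(range(10))
subset = resample_list(population, 5, rng=rng)
bootstrap = resample_list(population, 20, replacement=True, rng=rng)
```

Drawing without replacement yields distinct items, whereas drawing with replacement can return more items than the population holds, as in a bootstrap.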

quantile(input, probs, dim=0)[source]

Computes quantiles of input at probs. If probs is a scalar, the output will be squeezed at dim.

Parameters
  • input (torch.Tensor) – the input tensor.

  • probs (list) – quantile positions.

  • dim (int) – dimension to take quantiles from input.

Returns torch.Tensor

quantiles of input at probs.
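
A small sketch of an empirical quantile for a list of scalars, assuming linear interpolation between neighboring order statistics. The interpolation rule and the name `linear_quantile` are assumptions for illustration; the library may use a different convention.

```python
def linear_quantile(samples, prob):
    """Empirical quantile with linear interpolation (illustrative)."""
    ordered = sorted(samples)
    # Fractional position of the requested quantile among the sorted samples.
    position = prob * (len(ordered) - 1)
    below = int(position)
    above = min(below + 1, len(ordered) - 1)
    weight_above = position - below
    return (1 - weight_above) * ordered[below] + weight_above * ordered[above]

data = [50.0, 10.0, 40.0, 20.0, 30.0]
median = linear_quantile(data, 0.5)
lower_decile = linear_quantile(data, 0.1)
```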

weighed_quantile(input: torch.Tensor, probs: Union[List[float], Tuple[float, ...], torch.Tensor], log_weights: torch.Tensor, dim: int = 0) torch.Tensor[source]

Computes quantiles of weighted input samples at probs.

Parameters
  • input (torch.Tensor) – the input tensor.

  • probs (list) – quantile positions.

  • log_weights (torch.Tensor) – log weights of the samples.

  • dim (int) – dimension to take quantiles from input.

Returns torch.Tensor

quantiles of input at probs.

Example:

>>> from pyro.ops.stats import weighed_quantile
>>> import torch
>>> input = torch.Tensor([[10, 50, 40], [20, 30, 0]])
>>> probs = torch.Tensor([0.2, 0.8])
>>> log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
>>> result = weighed_quantile(input, probs, log_weights, -1)
>>> torch.testing.assert_close(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))
pi(input, prob, dim=0)[source]

Computes percentile interval which assigns equal probability mass to each tail of the interval.

Parameters
  • input (torch.Tensor) – the input tensor.

  • prob (float) – the probability mass of samples within the interval.

  • dim (int) – dimension to calculate percentile interval from input.

Returns torch.Tensor

quantiles of input at probabilities (1 - prob) / 2 and (1 + prob) / 2, i.e. the lower and upper bounds of the interval.

hpdi(input, prob, dim=0)[source]

Computes “highest posterior density interval” which is the narrowest interval with probability mass prob.

Parameters
  • input (torch.Tensor) – the input tensor.

  • prob (float) – the probability mass of samples within the interval.

  • dim (int) – dimension to calculate the highest posterior density interval from input.

Returns torch.Tensor

the lower and upper bounds of the interval.
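
The "narrowest interval" idea can be sketched on a sorted list: slide a window covering roughly a fraction `prob` of the samples and keep the shortest one. The windowing rule (`int(prob * n)` index steps between endpoints), the restriction to `prob < 1`, and the name `narrowest_interval` are assumptions made for this illustration.

```python
def narrowest_interval(samples, prob):
    """Shortest interval spanning about prob of the samples (needs prob < 1)."""
    ordered = sorted(samples)
    n = len(ordered)
    span = int(prob * n)  # index distance between the two endpoints
    # Candidate intervals are [ordered[i], ordered[i + span]]; keep the narrowest.
    start = min(range(n - span), key=lambda i: ordered[i + span] - ordered[i])
    return ordered[start], ordered[start + span]

data = [0.0, 10.0, 11.0, 12.0, 13.0, 14.0, 30.0, 50.0]
interval = narrowest_interval(data, 0.5)
```

Unlike an equal-tailed percentile interval, this interval follows the region where samples are most concentrated, here the cluster between 10 and 14.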

waic(input, log_weights=None, pointwise=False, dim=0)[source]

Computes “Widely Applicable/Watanabe-Akaike Information Criterion” (WAIC) and its corresponding effective number of parameters.

Reference:

[1] WAIC and cross-validation in Stan, Aki Vehtari, Andrew Gelman

Parameters
  • input (torch.Tensor) – the input tensor, which is log likelihood of a model.

  • log_weights (torch.Tensor) – log weights of samples along dim.

  • pointwise (bool) – if True, return the pointwise values instead of their sums.

  • dim (int) – the sample dimension of input.

Returns tuple

tuple of WAIC and effective number of parameters.
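
A rough sketch of the unweighted case, assuming the definitions in reference [1]: the log pointwise predictive density is the log of the mean likelihood over posterior draws, and the effective number of parameters is the sample variance of the log likelihood over draws, both summed over data points. The name `waic_unweighted` and the nested-list layout `log_lik[s][i]` are assumptions for illustration.

```python
import math
from statistics import variance

def waic_unweighted(log_lik):
    """WAIC from log_lik[s][i] = log p(y_i | theta_s), equal sample weights."""
    num_draws = len(log_lik)
    num_points = len(log_lik[0])
    lpd = 0.0
    p_waic = 0.0
    for i in range(num_points):
        column = [log_lik[s][i] for s in range(num_draws)]
        # Log of the mean likelihood, computed via log-sum-exp for stability.
        m = max(column)
        lpd += m + math.log(sum(math.exp(v - m) for v in column) / num_draws)
        # Assumed convention: sample variance with an S - 1 denominator.
        p_waic += variance(column)
    return -2 * (lpd - p_waic), p_waic

log_lik = [
    [math.log(0.5), math.log(0.2)],
    [math.log(0.5), math.log(0.4)],
]
waic_value, p_waic = waic_unweighted(log_lik)
```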

fit_generalized_pareto(X)[source]

Given a dataset X assumed to be drawn from the Generalized Pareto Distribution, estimate the distributional parameters k, sigma using a variant of the technique described in reference [1], as described in reference [2].

References

[1] A new and efficient estimation method for the generalized Pareto distribution, Zhang, J. and Stephens, M.A. (2009)

[2] Pareto Smoothed Importance Sampling, Aki Vehtari, Andrew Gelman, Jonah Gabry

Parameters

X (torch.Tensor) – the input data.

Returns tuple

tuple of floats (k, sigma) corresponding to the fit parameters

crps_empirical(pred, truth)[source]

Computes negative Continuous Ranked Probability Score CRPS* [1] between a set of samples pred and true data truth. This uses an n log(n) time algorithm to compute a quantity that would naively have complexity quadratic in the number of samples n:

CRPS* = E|pred - truth| - 1/2 E|pred - pred'|
      = (pred - truth).abs().mean(0)
      - (pred - pred.unsqueeze(1)).abs().mean([0, 1]) / 2

Note that for a single sample this reduces to absolute error.
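
The formula above can be checked on scalars in plain Python. The sketch compares the naive quadratic evaluation with a sort-based one that relies on the identity sum over pairs i < j of (x_(j) - x_(i)) = sum over i of (2i - n - 1) x_(i) for sorted samples with 1-based rank i. Both function names are hypothetical, and the sort-based variant is one possible n log(n) approach, not necessarily the one the library uses.

```python
def crps_naive(pred, truth):
    """CRPS* = E|pred - truth| - 1/2 E|pred - pred'|, evaluated in O(n^2)."""
    n = len(pred)
    term1 = sum(abs(p - truth) for p in pred) / n
    term2 = sum(abs(p - q) for p in pred for q in pred) / (n * n)
    return term1 - 0.5 * term2

def crps_sorted(pred, truth):
    """Same quantity in O(n log n) using sorted samples."""
    n = len(pred)
    term1 = sum(abs(p - truth) for p in pred) / n
    ordered = sorted(pred)
    # Sum of x_(j) - x_(i) over all pairs i < j, with 1-based rank i.
    pair_sum = sum((2 * i - n - 1) * x for i, x in enumerate(ordered, start=1))
    # 1/2 E|pred - pred'| = 1/2 * (2 * pair_sum / n^2) = pair_sum / n^2.
    return term1 - pair_sum / (n * n)

pred = [4.0, 1.0, 2.0]
truth = 3.0
score_naive = crps_naive(pred, truth)
score_sorted = crps_sorted(pred, truth)
```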

References

[1] Tilmann Gneiting, Adrian E. Raftery (2007), Strictly Proper Scoring Rules, Prediction, and Estimation, https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf

Parameters
  • pred (torch.Tensor) – A set of sample predictions batched on leftmost dim. This should have shape (num_samples,) + truth.shape.

  • truth (torch.Tensor) – A tensor of true observations.

Returns

A tensor of shape truth.shape.

Return type

torch.Tensor

energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor, pred_batch_size: Optional[int] = None) torch.Tensor[source]

Computes negative Energy Score ES* (see equation 22 in [1]) between a set of multivariate samples pred and a true data vector truth. Running time is quadratic in the number of samples n. In case of univariate samples the output coincides with the CRPS:

ES* = E|pred - truth| - 1/2 E|pred - pred'|

Note that for a single sample this reduces to the Euclidean norm of the difference between the sample pred and the truth.

This is a strictly proper score so that for pred distributed according to a distribution \(P\) and truth distributed according to a distribution \(Q\) we have \(ES^{*}(P,Q) \ge ES^{*}(Q,Q)\) with equality holding if and only if \(P=Q\), i.e. if \(P\) and \(Q\) have the same multivariate distribution (it is not sufficient for \(P\) and \(Q\) to have the same marginals in order for equality to hold).
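
A direct quadratic-time sketch of the ES* formula for one set of multivariate samples, using Euclidean distance for the norm. The function name `energy_score` and the list-of-vectors layout are assumptions for illustration and do not reflect the batching conventions of the library function.

```python
import math

def energy_score(pred, truth):
    """ES* = E|pred - truth| - 1/2 E|pred - pred'| with Euclidean norms."""
    n = len(pred)
    # Mean distance from each sample vector to the observed vector.
    term1 = sum(math.dist(p, truth) for p in pred) / n
    # Mean distance over all ordered pairs of samples (including i == j).
    term2 = sum(math.dist(p, q) for p in pred for q in pred) / (n * n)
    return term1 - 0.5 * term2

single = energy_score([[0.0, 0.0]], [3.0, 4.0])
univariate = energy_score([[1.0], [2.0], [4.0]], [3.0])
```

The two calls illustrate the special cases mentioned above: a single sample gives the Euclidean distance to the truth, and one-dimensional samples give the CRPS value.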

References

[1] Tilmann Gneiting, Adrian E. Raftery (2007), Strictly Proper Scoring Rules, Prediction, and Estimation, https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf

Parameters
  • pred (torch.Tensor) – A set of sample predictions batched on the second leftmost dim. The leftmost dim is that of the multivariate sample.

  • truth (torch.Tensor) – A tensor of true observations with same shape as pred except for the second leftmost dim which can have any value or be omitted.

  • pred_batch_size (int) – If specified the predictions will be batched before calculation according to the specified batch size in order to reduce memory consumption.

Returns

A tensor of shape truth.shape.

Return type

torch.Tensor

Streaming Statistics

class StreamingStats[source]

Bases: abc.ABC

Abstract base class for streamable statistics of trees of tensors.

Derived classes must implement update(), merge(), and get().

abstract update(sample) None[source]

Update state from a single sample.

This mutates self and returns nothing. Updates should be independent of order, i.e. samples should be exchangeable.

Parameters

sample – A sample value which is a nested dictionary of torch.Tensor leaves. This can have arbitrary nesting and shape, but the shape is assumed to be constant across calls to .update().

abstract merge(other) pyro.ops.streaming.StreamingStats[source]

Merge two aggregate statistics, e.g. from different MCMC chains.

This is a pure function: it returns a new StreamingStats object and does not modify either self or other.

Parameters

other – Another streaming stats instance of the same type.

abstract get() Any[source]

Return the aggregate statistic.
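
The update/merge/get contract can be sketched without tensors. In the example below, `StreamingStat` and `RunningMean` are hypothetical stand-ins operating on floats, not Pyro classes; they show an in-place `update()`, a pure `merge()` that leaves both inputs untouched, and a `get()` that returns a dictionary.

```python
from abc import ABC, abstractmethod

class StreamingStat(ABC):
    """Hypothetical stand-in for the abstract interface described above."""

    @abstractmethod
    def update(self, sample): ...

    @abstractmethod
    def merge(self, other): ...

    @abstractmethod
    def get(self): ...

class RunningMean(StreamingStat):
    def __init__(self):
        self.count = 0
        self.mean = 0.0

    def update(self, sample):
        # Mutates self; the result does not depend on the order of samples.
        self.count += 1
        self.mean += (sample - self.mean) / self.count

    def merge(self, other):
        # Pure function: builds a new object from the two aggregates.
        result = RunningMean()
        result.count = self.count + other.count
        if result.count:
            total = self.count * self.mean + other.count * other.mean
            result.mean = total / result.count
        return result

    def get(self):
        stats = {"count": self.count}
        if self.count:
            stats["mean"] = self.mean
        return stats

chain_a, chain_b = RunningMean(), RunningMean()
chain_a.update(1.0)
chain_a.update(2.0)
chain_b.update(6.0)
merged = chain_a.merge(chain_b)
```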

class StatsOfDict(types: Dict[Hashable, Callable[[], pyro.ops.streaming.StreamingStats]] = {}, default: Callable[[], pyro.ops.streaming.StreamingStats] = <class 'pyro.ops.streaming.CountStats'>)[source]

Bases: pyro.ops.streaming.StreamingStats

Statistics of samples that are dictionaries with a constant set of keys.

For example the following are equivalent:

>>> import torch
>>> from pyro.ops.streaming import CountMeanStats, CountStats, StatsOfDict
# Version 1. Hand encode statistics.
>>> a_stats = CountStats()
>>> b_stats = CountMeanStats()
>>> a_stats.update(torch.tensor(0.))
>>> b_stats.update(torch.tensor([1., 2.]))
>>> summary = {"a": a_stats.get(), "b": b_stats.get()}

# Version 2. Collect samples into dictionaries.
>>> stats = StatsOfDict({"a": CountStats, "b": CountMeanStats})
>>> stats.update({"a": torch.tensor(0.), "b": torch.tensor([1., 2.])})
>>> summary = stats.get()
>>> summary
{'a': {'count': 1}, 'b': {'count': 1, 'mean': tensor([1., 2.])}}
Parameters
  • default – Default type of statistics of values of the dictionary. Defaults to the inexpensive CountStats.

  • types (dict) – Dictionary mapping key to type of statistic that should be recorded for values corresponding to that key.

update(sample: Dict[Hashable, Any]) None[source]
merge(other: pyro.ops.streaming.StatsOfDict) pyro.ops.streaming.StatsOfDict[source]
get() Dict[Hashable, Any][source]
Returns

A dictionary of statistics. The keys of this dictionary are the same as the keys of the samples from which this object is updated.

Return type

dict

class StackStats[source]

Bases: pyro.ops.streaming.StreamingStats

Statistic collecting a stream of tensors into a single stacked tensor.

update(sample: torch.Tensor) None[source]
merge(other: pyro.ops.streaming.StackStats) pyro.ops.streaming.StackStats[source]
get() Dict[str, Union[int, torch.Tensor]][source]
Returns

A dictionary with keys count: int and (if any samples have been collected) samples: torch.Tensor.

Return type

dict

class CountStats[source]

Bases: pyro.ops.streaming.StreamingStats

Statistic tracking only the number of samples.

For example:

>>> import torch
>>> from pyro.ops.streaming import CountStats
>>> stats = CountStats()
>>> stats.update(torch.randn(3, 3))
>>> stats.get()
{'count': 1}
update(sample) None[source]
merge(other: pyro.ops.streaming.CountStats) pyro.ops.streaming.CountStats[source]
get() Dict[str, int][source]
Returns

A dictionary with key count: int.

Return type

dict

class CountMeanStats[source]

Bases: pyro.ops.streaming.StreamingStats

Statistic tracking the count and mean of a single torch.Tensor.

update(sample: torch.Tensor) None[source]
merge(other: pyro.ops.streaming.CountMeanStats) pyro.ops.streaming.CountMeanStats[source]
get() Dict[str, Union[int, torch.Tensor]][source]
Returns

A dictionary with keys count: int and (if any samples have been collected) mean: torch.Tensor.

Return type

dict

class CountMeanVarianceStats[source]

Bases: pyro.ops.streaming.StreamingStats

Statistic tracking the count, mean, and (diagonal) variance of a single torch.Tensor.

update(sample: torch.Tensor) None[source]
merge(other: pyro.ops.streaming.CountMeanVarianceStats) pyro.ops.streaming.CountMeanVarianceStats[source]
get() Dict[str, Union[int, torch.Tensor]][source]
Returns

A dictionary with keys count: int and (if any samples have been collected) mean: torch.Tensor and variance: torch.Tensor.

Return type

dict
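
One standard way to track a count, mean, and variance in a single pass is Welford's update, with the pairwise combination rule of Chan et al. for merging. The sketch below applies it to scalars; `RunningVariance` is a hypothetical class, and the unbiased (count - 1) denominator and the choice to omit the variance for a single sample are assumptions, not statements about the library's behavior.

```python
class RunningVariance:
    """Scalar count/mean/variance tracker (illustrative, Welford's method)."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the running mean

    def update(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def merge(self, other):
        out = RunningVariance()
        out.count = self.count + other.count
        if out.count:
            delta = other.mean - self.mean
            out.mean = self.mean + delta * other.count / out.count
            # Combine squared deviations and correct for the shift in means.
            out.m2 = (self.m2 + other.m2
                      + delta * delta * self.count * other.count / out.count)
        return out

    def get(self):
        stats = {"count": self.count}
        if self.count > 0:
            stats["mean"] = self.mean
        if self.count > 1:
            # Assumed convention: unbiased estimate.
            stats["variance"] = self.m2 / (self.count - 1)
        return stats

first, second, full = RunningVariance(), RunningVariance(), RunningVariance()
for value in [1.0, 2.0]:
    first.update(value)
    full.update(value)
for value in [3.0, 4.0]:
    second.update(value)
    full.update(value)
merged = first.merge(second).get()
```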

State Space Model and GP Utilities

class MaternKernel(nu=1.5, num_gps=1, length_scale_init=None, kernel_scale_init=None)[source]

Bases: pyro.nn.module.PyroModule

Provides the building blocks for representing univariate Gaussian Processes (GPs) with Matern kernels as state space models.

Parameters
  • nu (float) – The order of the Matern kernel (one of 0.5, 1.5 or 2.5)

  • num_gps (int) – the number of GPs

  • length_scale_init (torch.Tensor) – optional num_gps-dimensional vector of initializers for the length scale

  • kernel_scale_init (torch.Tensor) – optional num_gps-dimensional vector of initializers for the kernel scale

References

[1] Kalman Filtering and Smoothing Solutions to Temporal Gaussian Process Regression Models, Jouni Hartikainen and Simo Sarkka.

[2] Stochastic Differential Equation Methods for Spatio-Temporal Gaussian Process Regression, Arno Solin.

transition_matrix(dt)[source]

Compute the (exponentiated) transition matrix of the GP latent space. The resulting matrix has layout (num_gps, old_state, new_state), i.e. this matrix multiplies states from the right.

See section 5 in reference [1] for details.

Parameters

dt (float) – the time interval over which the GP latent space evolves.

Returns torch.Tensor

a 3-dimensional tensor of transition matrices of shape (num_gps, state_dim, state_dim).

stationary_covariance()[source]

Compute the stationary state covariance. See Eqn. 3.26 in reference [2].

Returns torch.Tensor

a 3-dimensional tensor of covariance matrices of shape (num_gps, state_dim, state_dim).

process_covariance(A)[source]

Given a transition matrix A computed with transition_matrix, compute the process covariance as described in Eqn. 3.11 in reference [2].

Returns torch.Tensor

a batched covariance matrix of shape (num_gps, state_dim, state_dim)
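
For intuition, the simplest case nu = 0.5 corresponds to the exponential kernel, whose state space form has a scalar state. The sketch below assumes the kernel k(tau) = kernel_scale**2 * exp(-|tau| / length_scale) and the stationary relation Q = P - A P A between the stationary variance P, transition A, and process variance Q. The function names and this parameterization are assumptions for illustration, not the class's actual code.

```python
import math

def ou_transition(dt, length_scale):
    """Scalar transition over a time step dt for the nu = 0.5 case."""
    return math.exp(-dt / length_scale)

def ou_stationary_variance(kernel_scale):
    """Stationary state variance, assumed equal to the kernel variance."""
    return kernel_scale ** 2

def ou_process_variance(a, kernel_scale):
    """Process noise variance Q = P - A P A for a scalar state."""
    p = ou_stationary_variance(kernel_scale)
    return p - a * p * a

a = ou_transition(1.0, length_scale=2.0)
q = ou_process_variance(a, kernel_scale=1.0)
```

With this choice the state variance stays at its stationary value after each step (A P A + Q = P), and transitions compose across adjacent intervals.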

transition_matrix_and_covariance(dt)[source]

Get the transition matrix and process covariance corresponding to a time interval dt.

Parameters

dt (float) – the time interval over which the GP latent space evolves.

Returns tuple

(transition_matrix, process_covariance) both 3-dimensional tensors of shape (num_gps, state_dim, state_dim)

training: bool