# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import functools
import weakref
from collections import namedtuple
import torch
from torch.distributions import constraints
from pyro.distributions.util import broadcast_shape, is_validation_enabled
from pyro.ops.special import safe_log
from .torch_distribution import TorchDistribution
class CoalescentTimesConstraint(constraints.Constraint):
    """
    Support constraint for coalescent times given fixed ``leaf_times``.

    A value (a vector of coalescent times along its rightmost dimension) is
    valid when at least one lineage exists at every event of the combined
    leaf/coalescent event sequence and, if ``ordered=True``, the coalescent
    times are sorted in non-decreasing order.

    :param torch.Tensor leaf_times: Times of sampling events (leaf nodes).
    :param bool ordered: Whether to additionally require ``value`` to be
        sorted along its rightmost dimension. Defaults to True.
    """

    # check() reduces over the rightmost dimension of ``value``, so the
    # constraint applies to whole vectors of coalescent times.
    event_dim = 1

    def __init__(self, leaf_times, *, ordered=True):
        self.leaf_times = leaf_times
        self.ordered = ordered

    def check(self, value):
        """
        Returns a boolean tensor over the batch dimensions (the rightmost
        dimension of ``value`` is reduced) indicating which values are valid.
        """
        # There must always be at least one lineage.
        coal_times = value
        phylogeny = _make_phylogeny(self.leaf_times, coal_times)
        at_least_one_lineage = (phylogeny.lineages > 0).all(dim=-1)
        if not self.ordered:
            return at_least_one_lineage

        # Inputs must be ordered.
        ordered = (value[..., :-1] <= value[..., 1:]).all(dim=-1)
        return ordered & at_least_one_lineage
class CoalescentTimes(TorchDistribution):
    """
    Distribution over sorted coalescent times given irregular sampled
    ``leaf_times`` and constant population size.

    Sample values will be **sorted** sets of binary coalescent times. Each
    sample ``value`` will have cardinality ``value.size(-1) =
    leaf_times.size(-1) - 1``, so that phylogenies are complete binary trees.
    This distribution can thus be batched over multiple samples of phylogenies
    given fixed (number of) leaf times, e.g. over phylogeny samples from BEAST
    or MrBayes.

    **References**

    [1] J.F.C. Kingman (1982)
        "On the Genealogy of Large Populations"
        Journal of Applied Probability
    [2] J.F.C. Kingman (1982)
        "The Coalescent"
        Stochastic Processes and their Applications

    :param torch.Tensor leaf_times: Vector of times of sampling events, i.e.
        leaf nodes in the phylogeny. These can be arbitrary real numbers with
        arbitrary order and duplicates.
    :param torch.Tensor rate: Base coalescent rate (pairwise rate of
        coalescence) under a constant population size model. Defaults to 1.
    """

    arg_constraints = {"leaf_times": constraints.real, "rate": constraints.positive}

    def __init__(self, leaf_times, rate=1.0, *, validate_args=None):
        rate = torch.as_tensor(rate, dtype=leaf_times.dtype, device=leaf_times.device)
        batch_shape = broadcast_shape(rate.shape, leaf_times.shape[:-1])
        event_shape = (leaf_times.size(-1) - 1,)
        self.leaf_times = leaf_times
        self.rate = rate
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=1)
    def support(self):
        return CoalescentTimesConstraint(self.leaf_times)

    def log_prob(self, value):
        """
        Log density of a (batch of) sorted coalescent time vectors ``value``.

        :param torch.Tensor value: Coalescent times, with
            ``value.size(-1) == leaf_times.size(-1) - 1``.
        :rtype: torch.Tensor
        """
        if self._validate_args:
            self._validate_sample(value)
        coal_times = value
        phylogeny = _make_phylogeny(self.leaf_times, coal_times)

        # The coalescent process is like a Poisson process with rate binomial
        # in the number of lineages, which changes at each event.
        binomial = phylogeny.binomial[..., :-1]
        interval = phylogeny.times[..., :-1] - phylogeny.times[..., 1:]
        log_prob = self.rate.log() * coal_times.size(-1) - self.rate * (
            binomial * interval
        ).sum(-1)

        # Scaling by those rates and accounting for log|jacobian|, the density
        # is that of a collection of independent Exponential intervals.
        log_abs_det_jacobian = phylogeny.coal_binomial.log().sum(-1).neg()
        return log_prob - log_abs_det_jacobian

    def sample(self, sample_shape=torch.Size()):
        """
        Draws sorted coalescent times. Sampling is not differentiable.
        """
        with torch.no_grad():
            shape = self._extended_shape(sample_shape)[:-1]
            # A coalescent with pairwise rate r is a unit-rate coalescent run on
            # a time axis rescaled by r, so sample in rescaled time and map back.
            rate = self.rate.unsqueeze(-1)
            leaf_times = (self.leaf_times * rate).expand(shape + (-1,))
            return _sample_coalescent_times(leaf_times) / rate
class CoalescentTimesWithRate(TorchDistribution):
    """
    Distribution over coalescent times given irregular sampled ``leaf_times``
    and piecewise constant coalescent rates defined on a regular time grid.

    This assumes a piecewise constant base coalescent rate specified on time
    intervals ``(-inf,1]``, ``[1,2]``, ..., ``[T-1,inf)``, where ``T =
    rate_grid.size(-1)``. Leaves may be sampled at arbitrary real times, but
    are commonly sampled in the interval ``[0, T]``.

    Sample values will be sorted sets of binary coalescent times. Each sample
    ``value`` will have cardinality ``value.size(-1) = leaf_times.size(-1) -
    1``, so that phylogenies are complete binary trees. This distribution can
    thus be batched over multiple samples of phylogenies given fixed (number
    of) leaf times, e.g. over phylogeny samples from BEAST or MrBayes.

    This distribution implements :meth:`log_prob` but not ``.sample()``.

    See also :class:`~pyro.distributions.CoalescentRateLikelihood`.

    **References**

    [1] J.F.C. Kingman (1982)
        "On the Genealogy of Large Populations"
        Journal of Applied Probability
    [2] J.F.C. Kingman (1982)
        "The Coalescent"
        Stochastic Processes and their Applications
    [3] A. Popinga, T. Vaughan, T. Statler, A.J. Drummond (2014)
        "Inferring epidemiological dynamics with Bayesian coalescent inference:
        The merits of deterministic and stochastic models"

    :param torch.Tensor leaf_times: Tensor of times of sampling events, i.e.
        leaf nodes in the phylogeny. These can be arbitrary real numbers with
        arbitrary order and duplicates.
    :param torch.Tensor rate_grid: Tensor of base coalescent rates (pairwise
        rate of coalescence). For example in a simple SIR model this might be
        ``beta S / I``. The rightmost dimension is time, and this tensor
        represents a (batch of) rates that are piecewise constant in time.
    """

    arg_constraints = {
        "leaf_times": constraints.real,
        "rate_grid": constraints.positive,
    }

    def __init__(self, leaf_times, rate_grid, *, validate_args=None):
        batch_shape = broadcast_shape(leaf_times.shape[:-1], rate_grid.shape[:-1])
        event_shape = (leaf_times.size(-1) - 1,)
        self.leaf_times = leaf_times
        self.rate_grid = rate_grid
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=1)
    def support(self):
        return CoalescentTimesConstraint(self.leaf_times)

    @property
    def duration(self):
        """Number of cells ``T`` in the rate grid."""
        return self.rate_grid.size(-1)

    def expand(self, batch_shape, _instance=None):
        # Parameters are shared, not expanded; log_prob() broadcasts its
        # result up to the new batch_shape.
        batch_shape = torch.Size(batch_shape)
        new = self._get_checked_instance(CoalescentTimesWithRate, _instance)
        new.leaf_times = self.leaf_times
        new.rate_grid = self.rate_grid
        super(CoalescentTimesWithRate, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self.__dict__.get("_validate_args")
        return new

    def log_prob(self, value):
        """
        Computes the log-likelihood as in equations 7-8 of [3].

        This has time complexity ``O(T + S N log(N))`` where ``T`` is the
        number of time steps, ``N`` is the number of leaves, and ``S =
        sample_shape.numel()`` is the number of samples of ``value``.

        :param torch.Tensor value: A tensor of coalescent times. These denote
            sets of size ``leaf_times.size(-1) - 1`` along the trailing
            dimension and should be sorted along that dimension.
        :returns: Log-likelihood ``log p(coal_times | leaf_times, rate_grid)``
        :rtype: torch.Tensor
        """
        if self._validate_args:
            self._validate_sample(value)
        coal_times = value
        phylogeny = _make_phylogeny(self.leaf_times, coal_times)

        # Compute survival factors for closed intervals.
        cumsum = self.rate_grid.cumsum(-1)
        cumsum = torch.nn.functional.pad(cumsum, (1, 0), value=0)
        integral = _interpolate_gather(
            cumsum, phylogeny.times[..., 1:]
        )  # ignore the final lonely leaf
        integral = integral[..., :-1] - integral[..., 1:]
        integral = integral.clamp(min=torch.finfo(integral.dtype).tiny)  # avoid nan
        log_prob = -(phylogeny.binomial[..., 1:-1] * integral).sum(-1)

        # Compute density of coalescent events.
        i = coal_times.floor().clamp(min=0, max=self.duration - 1).long()
        rates = phylogeny.coal_binomial * _gather(self.rate_grid, -1, i)
        log_prob = log_prob + safe_log(rates).sum(-1)

        batch_shape = broadcast_shape(self.batch_shape, value.shape[:-1])
        log_prob = log_prob.expand(batch_shape)
        return log_prob
class CoalescentRateLikelihood:
    """
    **EXPERIMENTAL** This is not a :class:`~pyro.distributions.Distribution`,
    but acts as a transposed version of :class:`CoalescentTimesWithRate`
    making the elements of ``rate_grid`` independent and thus compatible with
    ``plate`` and ``poutine.markov``. For non-batched inputs the following are
    all equivalent likelihoods::

        # Version 1.
        pyro.sample(
            "coalescent",
            CoalescentTimesWithRate(leaf_times, rate_grid),
            obs=coal_times,
        )

        # Version 2. using pyro.plate
        likelihood = CoalescentRateLikelihood(leaf_times, coal_times, len(rate_grid))
        with pyro.plate("time", len(rate_grid)):
            pyro.factor("coalescent", likelihood(rate_grid))

        # Version 3. using pyro.markov
        likelihood = CoalescentRateLikelihood(leaf_times, coal_times, len(rate_grid))
        for t in pyro.markov(range(len(rate_grid))):
            pyro.factor("coalescent_{}".format(t), likelihood(rate_grid[t], t))

    The third version is useful for e.g.
    :class:`~pyro.infer.smcfilter.SMCFilter` where ``rate_grid`` might be
    computed sequentially.

    :param torch.Tensor leaf_times: Tensor of times of sampling events, i.e.
        leaf nodes in the phylogeny. These can be arbitrary real numbers with
        arbitrary order and duplicates.
    :param torch.Tensor coal_times: A tensor of coalescent times. These denote
        sets of size ``leaf_times.size(-1) - 1`` along the trailing dimension
        and should be sorted along that dimension.
    :param int duration: Size of the rate grid, ``rate_grid.size(-1)``.
    :param bool validate_args: Whether to check that every event has at least
        one lineage. Defaults to Pyro's global validation setting.
    :raises ValueError: If validation is enabled and
        ``(leaf_times, coal_times)`` is invalid.
    """

    def __init__(self, leaf_times, coal_times, duration, *, validate_args=None):
        assert leaf_times.size(-1) == 1 + coal_times.size(-1)
        assert isinstance(duration, int) and duration >= 2
        # is_validation_enabled must be *called*: the bare function object is
        # always truthy and would silently force validation on.
        if validate_args is True or (
            validate_args is None and is_validation_enabled()
        ):
            constraint = CoalescentTimesConstraint(leaf_times, ordered=False)
            if not constraint.check(coal_times).all():
                raise ValueError("Invalid (leaf_times, coal_times)")

        phylogeny = _make_phylogeny(leaf_times, coal_times)
        batch_shape = phylogeny.times.shape[:-1]
        new_zeros = leaf_times.new_zeros
        new_ones = leaf_times.new_ones

        # Construct linear part from intervals of survival outside of [0,duration].
        times = phylogeny.times.clamp(max=0)
        intervals = times[..., 1:] - times[..., :-1]
        pre_linear = (phylogeny.binomial[..., :-1] * intervals).sum(-1, keepdim=True)
        times = phylogeny.times.clamp(min=duration)
        intervals = times[..., 1:] - times[..., :-1]
        post_linear = (phylogeny.binomial[..., :-1] * intervals).sum(-1, keepdim=True)
        # Time before 0 is charged to the first cell, time after `duration` to
        # the last cell; interior cells start at zero.
        self._linear = torch.cat(
            [
                pre_linear,
                new_zeros(pre_linear.shape[:-1] + (duration - 2,)),
                post_linear,
            ],
            dim=-1,
        )

        # Construct linear part from intervals of survival within [0, duration].
        times = phylogeny.times.clamp(min=0, max=duration)
        sparse_diff = phylogeny.binomial[..., :-1] - phylogeny.binomial[..., 1:]
        dense_diff = new_zeros(batch_shape + (1 + duration,))
        _interpolate_scatter_add_(dense_diff, times[..., 1:], sparse_diff)
        self._linear += dense_diff.flip([-1]).cumsum(-1)[..., :-1].flip([-1])

        # Construct const and log part from coalescent events.
        coal_index = coal_times.floor().clamp(min=0, max=duration - 1).long()
        # scatter_add_ needs the index to cover the full batch shape, which may
        # be larger than coal_times' own batch shape if leaf_times is batched.
        coal_index = coal_index.expand(batch_shape + coal_index.shape[-1:])
        self._const = new_zeros(batch_shape + (duration,))
        self._const.scatter_add_(-1, coal_index, phylogeny.coal_binomial.log())
        self._log = new_zeros(batch_shape + (duration,))
        self._log.scatter_add_(-1, coal_index, new_ones(coal_index.shape))

    def __call__(self, rate_grid, t=slice(None)):
        """
        Computes the log-likelihood of [1] equations 7-9 for one or all time
        points.

        **References**

        [1] A. Popinga, T. Vaughan, T. Statler, A.J. Drummond (2014)
            "Inferring epidemiological dynamics with Bayesian coalescent
            inference: The merits of deterministic and stochastic models"

        :param torch.Tensor rate_grid: Tensor of base coalescent rates
            (pairwise rate of coalescence). For example in a simple SIR model
            this might be ``beta S / I``. The rightmost dimension is time, and
            this tensor represents a (batch of) rates that are piecewise
            constant in time.
        :param t: Optional time index by which the input was sliced, as in
            ``rate_grid[..., t]``. This can be an integer for sequential models
            or ``slice(None)`` for vectorized models.
        :type t: int or slice
        :returns: Log-likelihood ``log p(coal_times | leaf_times, rate_grid)``,
            or a part of that log-likelihood corresponding to a single time
            step.
        :rtype: torch.Tensor
        """
        const = self._const[..., t]
        linear = self._linear[..., t] * rate_grid
        # Clamp so zero rates give a large negative finite value, not -inf/nan.
        log = (
            self._log[..., t]
            * rate_grid.clamp(min=torch.finfo(rate_grid.dtype).tiny).log()
        )
        return const + linear + log
def bio_phylo_to_times(tree, *, get_time=None):
    """
    Extracts coalescent summary statistics from a phylogeny, suitable for use
    with :class:`~pyro.distributions.CoalescentRateLikelihood`.

    :param Bio.Phylo.BaseTree.Clade tree: A phylogenetic tree.
    :param callable get_time: Optional function to extract the time point of
        each sub-:class:`~Bio.Phylo.BaseTree.Clade`. If absent, times will be
        computed by cumulative ``.branch_length`` starting from the root, with
        missing branch lengths treated as 1.0.
    :returns: A pair of :class:`~torch.Tensor` s ``(leaf_times, coal_times)``
        where ``leaf_times`` are times of sampling events (leaf nodes in the
        phylogenetic tree) and ``coal_times`` are times of coalescences
        (internal nodes in the phylogenetic tree). An internal node with
        ``n`` children contributes ``n - 1`` binary coalescences.
    :rtype: tuple
    """
    if get_time is None:
        # Compute time as cumulative branch length.
        def get_branch_length(clade):
            branch_length = clade.branch_length
            return 1.0 if branch_length is None else branch_length

        times = {tree.root: get_branch_length(tree.root)}

    leaf_times = []
    coal_times = []
    for clade in tree.find_clades():
        if get_time is None:
            # Parents are visited before children, so each child's time is
            # filled in before it is reached.
            time = times[clade]
            for child in clade:
                times[child] = time + get_branch_length(child)
        else:
            time = get_time(clade)

        num_children = len(clade)
        if num_children == 0:
            leaf_times.append(time)
        else:
            # Pyro expects binary coalescent events, so we split n-ary events
            # into n-1 separate binary events.
            coal_times.extend([time] * (num_children - 1))
    assert len(leaf_times) == 1 + len(coal_times)

    leaf_times = torch.tensor(leaf_times)
    coal_times = torch.tensor(coal_times)
    return leaf_times, coal_times
def _gather(tensor, dim, index):
    """
    Like :func:`torch.gather` but broadcasts the leading (batch) dimensions
    of ``tensor`` and ``index`` against each other.

    Only ``dim=-1`` is supported.
    """
    if dim != -1:
        raise NotImplementedError("_gather only supports dim=-1")
    shape = broadcast_shape(tensor.shape[:-1], index.shape[:-1]) + (-1,)
    tensor = tensor.expand(shape)
    index = index.expand(shape)
    return tensor.gather(dim, index)
def _interpolate_gather(array, x):
    """
    Like ``torch.gather(array, -1, x)`` but continuously indexes into the
    rightmost dim of an array, linearly interpolating between array values.

    Leading dimensions of ``array`` and ``x`` are broadcast. Positions outside
    ``[0, array.size(-1) - 1]`` are linearly extrapolated from the first or
    last segment.
    """
    # Integer bracket positions are constants w.r.t. autograd; gradients flow
    # through the interpolation weights (via x) and the gathered values.
    with torch.no_grad():
        x0 = x.floor().clamp(min=0, max=array.size(-1) - 2)
        x1 = x0 + 1
    f0 = _gather(array, -1, x0.long())
    f1 = _gather(array, -1, x1.long())
    return f0 * (x1 - x) + f1 * (x - x0)
def _interpolate_scatter_add_(dst, x, src):
    """
    Like ``dst.scatter_add_(-1, x, src)`` but continuously indexes into the
    rightmost dim of ``dst``, splitting each ``src`` value between the two
    neighboring integer positions with linear-interpolation weights.

    Modifies ``dst`` in place and returns it. Positions outside
    ``[0, dst.size(-1) - 1]`` use the first or last segment's weights
    (linear extrapolation).
    """
    # Integer bracket positions are constants w.r.t. autograd; gradients flow
    # through src and the interpolation weights.
    with torch.no_grad():
        x0 = x.floor().clamp(min=0, max=dst.size(-1) - 2)
        x1 = x0 + 1
    dst.scatter_add_(-1, x0.long(), src * (x1 - x))
    dst.scatter_add_(-1, x1.long(), src * (x - x0))
    return dst
def _weak_memoize(fn):
    """
    Memoizes ``fn`` on the identities of its (tensor) positional arguments.

    A cached result is reused only while none of the arguments has been
    mutated in place since it was computed (tracked via ``Tensor._version``).
    Cache entries are evicted as soon as any argument is garbage collected,
    so a recycled ``id()`` can never produce a false hit.
    """
    cache = {}

    @functools.wraps(fn)
    def memoized_fn(*args):
        key = tuple(map(id, args))
        # Allow cache hit only when tensors have not since been mutated.
        version = tuple(arg._version for arg in args)
        is_new_key = key not in cache
        if not is_new_key:
            old_version, result = cache[key]
            if old_version == version:
                return result
        result = fn(*args)
        cache[key] = version, result
        if is_new_key:
            # Register eviction finalizers only on first insertion; re-adding
            # them on every recompute after a mutation would accumulate
            # duplicate finalizers for long-lived arguments.
            for arg in args:
                weakref.finalize(arg, cache.pop, key, None)
        return result

    return memoized_fn
# This helper data structure has only timing information. Its tensors are
# built by _make_phylogeny(), with events ordered from latest to earliest.
_Phylogeny = namedtuple(
    "_Phylogeny", ("times", "signs", "lineages", "binomial", "coal_binomial")
)
@_weak_memoize
def _make_phylogeny(leaf_times, coal_times):
    """
    Merges leaf and coalescent times into one time-sorted event sequence.

    :param torch.Tensor leaf_times: ``N`` sampling times along the rightmost
        dimension.
    :param torch.Tensor coal_times: ``N - 1`` coalescent times along the
        rightmost dimension.
    :returns: A ``_Phylogeny`` whose tensors have the broadcast batch shape of
        the inputs. ``times``, ``signs``, ``lineages`` and ``binomial`` have
        ``2N - 1`` events sorted from latest to earliest; ``coal_binomial``
        has one entry per coalescent event, in the order of ``coal_times``.
    """
    assert leaf_times.size(-1) == 1 + coal_times.size(-1)

    # Expand shapes to match.
    N = leaf_times.size(-1)
    batch_shape = broadcast_shape(leaf_times.shape[:-1], coal_times.shape[:-1])
    if leaf_times.shape[:-1] != batch_shape:
        leaf_times = leaf_times.expand(batch_shape + (N,))
    if coal_times.shape[:-1] != batch_shape:
        coal_times = coal_times.expand(batch_shape + (N - 1,))

    # Combine N sampling events (leaf_times) plus N-1 coalescent events
    # (coal_times) into a pair (times, signs) of arrays of length 2N-1, where
    # leaf sample sign is +1 and coalescent sign is -1. Signs are created on
    # the device/dtype of the times so that indexing by the sort permutation
    # works off-CPU and downstream arithmetic keeps the input dtype.
    times = torch.cat([coal_times, leaf_times], dim=-1)
    signs = torch.cat(
        [times.new_full((N - 1,), -1.0), times.new_ones(N)]
    )  # e.g. [-1, -1, +1, +1, +1]

    # Sort the events reverse-ordered in time, i.e. latest to earliest.
    times, index = times.sort(dim=-1, descending=True)
    signs = signs[index]
    inv_index = index.new_empty(index.shape)
    inv_index.scatter_(
        -1, index, torch.arange(2 * N - 1, device=index.device).expand_as(index)
    )

    # Compute the number n of lineages preceding each event, then the binomial
    # coefficients that will multiply the base coalescence rate.
    lineages = signs.cumsum(-1)
    binomial = lineages * (lineages - 1) / 2

    # Compute the binomial coefficient following each coalescent event.
    coal_index = inv_index[..., : N - 1]
    # A coalescent event sorted first (later than every leaf) has no preceding
    # event; clamp to keep gather in bounds. Such inputs yield a negative
    # lineage count, which CoalescentTimesConstraint reports as invalid.
    coal_binomial = binomial.gather(-1, (coal_index - 1).clamp(min=0))

    return _Phylogeny(times, signs, lineages, binomial, coal_binomial)
def _sample_coalescent_times(leaf_times):
    """
    Draws binary coalescent times from a unit-rate Kingman coalescent.

    :param torch.Tensor leaf_times: Sampling times along the rightmost
        dimension (size ``N``), with optional leading batch dimensions.
    :returns: A tensor of shape ``leaf_times.shape[:-1] + (N - 1,)`` on the
        dtype/device of ``leaf_times``, sorted in non-decreasing order along
        the rightmost dimension. The result carries no gradient.
    :rtype: torch.Tensor
    """
    leaf_times = leaf_times.detach()
    proto = leaf_times
    N = leaf_times.size(-1)
    batch_shape = leaf_times.shape[:-1]

    # We don't bother to implement a version that vectorizes over batches;
    # instead we simply sequentially sample and stack.
    if batch_shape:
        flat_leaf_times = leaf_times.reshape(-1, N)
        if flat_leaf_times.size(0) == 0:
            # torch.stack() rejects an empty list.
            return proto.new_empty(batch_shape + (N - 1,))
        flat_coal_times = torch.stack(
            list(map(_sample_coalescent_times, flat_leaf_times))
        )
        return flat_coal_times.reshape(batch_shape + (N - 1,))
    assert leaf_times.shape == (N,)

    # A single leaf has no coalescent events.
    if N == 1:
        return proto.new_empty(0)

    # Sequentially sample coalescent events from latest to earliest.
    leaf_times = leaf_times.sort(dim=-1, descending=True).values.tolist()
    coal_times = []

    # Start with the minimum of two active leaves.
    leaf = 1
    t = leaf_times[leaf]
    active = 2
    binomial = active * (active - 1) / 2
    for u in proto.new_empty(N - 1).exponential_().tolist():
        while leaf + 1 < N and u > (t - leaf_times[leaf + 1]) * binomial:
            # Move past the next leaf.
            leaf += 1
            u -= (t - leaf_times[leaf]) * binomial
            t = leaf_times[leaf]
            active += 1
            binomial = active * (active - 1) / 2

        # Add a coalescent event.
        t -= u / binomial
        active -= 1
        binomial = active * (active - 1) / 2
        coal_times.append(t)

    # Events were generated latest-to-earliest; return them in increasing
    # order, as required by the ordered support constraint.
    coal_times.reverse()
    return proto.new_tensor(coal_times)