# Source code for pyro.infer.reparam.split

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

from .reparam import Reparam


class SplitReparam(Reparam):
    """
    Reparameterizer to split a random variable along a dimension, similar to
    :func:`torch.split`.

    This is useful for treating different parts of a tensor with different
    reparameterizers or inference methods. For example when performing HMC
    inference on a time series, you can first apply
    :class:`~pyro.infer.reparam.discrete_cosine.DiscreteCosineReparam` or
    :class:`~pyro.infer.reparam.haar.HaarReparam`, then apply
    :class:`SplitReparam` to split into low-frequency and high-frequency
    components, and finally add the low-frequency components to the
    ``full_mass`` matrix together with globals.

    :param sections: Size of a single chunk or list of sizes for each chunk.
        When a single int is given, the split dimension is cut into chunks of
        that size (the last chunk may be smaller), as in :func:`torch.split`.
        A list must sum to the size of the split dimension.
    :type sections: int or list(int)
    :param int dim: Dimension along which to split. Must be negative, i.e.
        counted from the right of the event shape. Defaults to -1.
    """

    def __init__(self, sections, dim=-1):
        assert isinstance(dim, int) and dim < 0, "dim must be a negative int"
        assert isinstance(sections, (int, list, tuple)), (
            "sections must be an int or a list of ints"
        )
        if isinstance(sections, int):
            assert sections > 0, "sections must be a positive int"
        else:
            # Store a list regardless of whether a list or tuple was given.
            sections = list(sections)
            assert all(isinstance(size, int) for size in sections), (
                "sections must contain only ints"
            )
        # Number of rightmost event dims, counting the split dim itself.
        self.event_dim = -dim
        self.sections = sections

    def _section_sizes(self, total):
        """
        Return the explicit list of chunk sizes for a split dim of size ``total``.

        A list ``self.sections`` is returned unchanged. An int chunk size is
        expanded the way :func:`torch.split` does, which always yields at
        least one chunk, even for an empty dimension.
        """
        if not isinstance(self.sections, int):
            return self.sections
        size = self.sections
        return [min(size, total - start) for start in range(0, total, size)] or [0]

    def apply(self, msg):
        """
        Replace the site described by ``msg`` by one ``{name}_split_{i}``
        sample site per section, each with an
        :class:`~pyro.distributions.ImproperUniform` distribution. Then return
        a :class:`~pyro.distributions.Delta` at the concatenated value that
        carries the original distribution's log density.

        :param dict msg: The sample site message; reads ``name``, ``fn``,
            ``value`` and ``is_observed``.
        :returns: A dict with the new ``fn``, the ``value``, and
            ``is_observed=True``.
        :rtype: dict
        """
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]
        assert fn.event_dim >= self.event_dim

        # Index of the split dimension within fn.event_shape.
        dim = fn.event_dim - self.event_dim
        sections = self._section_sizes(fn.event_shape[dim])

        # Split value into parts.
        value_split = [None] * len(sections)
        if value is not None:
            value_split[:] = value.split(sections, -self.event_dim)

        # Draw independent parts.
        left_shape = fn.event_shape[:dim]
        right_shape = fn.event_shape[1 + dim :]
        for i, size in enumerate(sections):
            event_shape = left_shape + (size,) + right_shape
            value_split[i] = pyro.sample(
                f"{name}_split_{i}",
                dist.ImproperUniform(fn.support, fn.batch_shape, event_shape),
                obs=value_split[i],
                infer={"is_observed": is_observed},
            )

        # Combine parts into value.
        if value is None:
            value = torch.cat(value_split, dim=-self.event_dim)

        # When the site is masked out, skip computing the (unused) log density.
        if poutine.get_mask() is False:
            log_density = 0.0
        else:
            log_density = fn.log_prob(value)
        new_fn = dist.Delta(value, event_dim=fn.event_dim, log_density=log_density)
        return {"fn": new_fn, "value": value, "is_observed": True}