# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch

import pyro
import pyro.distributions as dist

from .reparam import Reparam

class SplitReparam(Reparam):
    """
    Reparameterizer to split a random variable along a dimension, similar to
    :func:`torch.split`.

    This is useful for treating different parts of a tensor with different
    reparameterizers or inference methods. For example when performing HMC
    inference on a time series, you can first apply
    :class:`~pyro.infer.reparam.discrete_cosine.DiscreteCosineReparam` or
    :class:`~pyro.infer.reparam.haar.HaarReparam`, then apply
    :class:`SplitReparam` to split into low-frequency and high-frequency
    components, and finally add the low-frequency components to the
    ``full_mass`` matrix together with globals.

    :param sections: List of sizes for each chunk.
    :type: list(int)
    :param int dim: Dimension along which to split. Must be negative, i.e.
        counted from the right. Defaults to -1.
    """

    def __init__(self, sections, dim=-1):
        assert isinstance(dim, int) and dim < 0
        assert isinstance(sections, list)
        assert all(isinstance(size, int) for size in sections)
        # Store the split dimension as a positive count from the right.
        self.event_dim = -dim
        self.sections = sections

    def __call__(self, name, fn, obs):
        """
        Replace the site ``name`` by one auxiliary sample site per section.

        Each auxiliary site is named ``"{name}_split_{i}"`` and is drawn from
        a :class:`~pyro.distributions.ImproperUniform`; the parts are then
        concatenated along the split dimension.

        :param str name: Name of the sample site being reparameterized.
        :param fn: Distribution at the site. Its ``event_dim`` must be at
            least as large as the number of dims implied by ``dim``.
        :param obs: Observed value; must be ``None``.
        :returns: A pair ``(new_fn, value)`` where ``value`` is the
            concatenation of the parts and ``new_fn`` is a
            :class:`~pyro.distributions.Delta` at ``value`` whose
            ``log_density`` is ``fn.log_prob(value)``.
        """
        assert fn.event_dim >= self.event_dim
        assert obs is None, "SplitReparam does not support observe statements"

        # Draw independent parts.
        # Translate the from-the-right split dim into a nonnegative index
        # into ``fn.event_shape``.
        dim = fn.event_dim - self.event_dim
        left_shape = fn.event_shape[:dim]
        right_shape = fn.event_shape[1 + dim:]
        parts = []
        for i, size in enumerate(self.sections):
            # Same event shape as ``fn`` except for the size of the split dim.
            event_shape = left_shape + (size,) + right_shape
            parts.append(pyro.sample(
                "{}_split_{}".format(name, i),
                dist.ImproperUniform(fn.support, fn.batch_shape, event_shape)))
        value = torch.cat(parts, dim=-self.event_dim)

        # Combine parts.
        # The improper parts carry no density of their own; the original
        # density is reattached through the Delta's ``log_density``.
        log_prob = fn.log_prob(value)
        new_fn = dist.Delta(value, event_dim=fn.event_dim, log_density=log_prob)
        return new_fn, value