# Source code for pyro.distributions.nanmasked

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch

from .torch import MultivariateNormal, Normal


class NanMaskedNormal(Normal):
    """
    Wrapper around :class:`~pyro.distributions.Normal` that tolerates
    partially observed data: any non-finite (NAN) element passed to
    :meth:`log_prob` contributes zero to the result. This is useful for
    likelihoods with missing data.

    Example::

        from math import nan
        data = torch.tensor([0.5, 0.1, nan, 0.9])
        with pyro.plate("data", len(data)):
            pyro.sample("obs", NanMaskedNormal(0, 1), obs=data)
    """

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Return elementwise log density, with zeros at non-finite entries."""
        observed = value.isfinite()
        # Fast path: nothing is missing, defer to the plain Normal.
        if observed.all():
            return super().log_prob(value)

        # Broadcast data, mask, and parameters to a common shape.
        value, observed, mean, std = torch.broadcast_tensors(
            value, observed, self.loc, self.scale
        )
        out = value.new_zeros(value.shape)

        # Score only the observed entries; missing ones keep log_prob == 0.
        if observed.any():
            sub = Normal(mean[observed], std[observed], validate_args=False)
            out[observed] = sub.log_prob(value[observed])
        return out
class NanMaskedMultivariateNormal(MultivariateNormal):
    """
    Wrapper around :class:`~pyro.distributions.MultivariateNormal` to allow
    partially observed data as specified by NAN elements in the argument to
    :meth:`log_prob`. The ``log_prob`` of these events will marginalize over
    the NAN elements. This is useful for likelihoods with missing data.

    Example::

        from math import nan
        data = torch.tensor([
            [0.1, 0.2, 3.4],
            [0.5, 0.1, nan],
            [0.6, nan, nan],
            [nan, 0.5, nan],
            [nan, nan, nan],
        ])
        with pyro.plate("data", len(data)):
            pyro.sample(
                "obs",
                NanMaskedMultivariateNormal(torch.zeros(3), torch.eye(3)),
                obs=data,
            )
    """

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Return batched log density, marginalizing over NAN components.

        For each event vector, the NAN coordinates are dropped and the
        density of the remaining coordinates is computed under the
        corresponding marginal (sub-)MultivariateNormal. Fully-missing
        events get ``log_prob == 0``.
        """
        ok = value.isfinite()
        if ok.all():
            return super().log_prob(value)

        # Broadcast all tensors. This might waste some computation by eagerly
        # broadcasting, but the optimal implementation is quite complex.
        value, ok, loc = torch.broadcast_tensors(value, ok, self.loc)
        cov = self.covariance_matrix.expand(loc.shape + loc.shape[-1:])

        # Flatten the batch dimensions so rows can be grouped by mask pattern.
        result_shape = value.shape[:-1]
        n = result_shape.numel()
        p = value.shape[-1]
        value = value.reshape(n, p)
        ok = ok.reshape(n, p)
        loc = loc.reshape(n, p)
        cov = cov.reshape(n, p, p)
        result = value.new_zeros(n)

        # Evaluate each distinct observation pattern once, batching all rows
        # that share it.
        for pattern in sorted(set(map(tuple, ok.tolist()))):
            if not any(pattern):
                continue  # fully-missing rows keep log_prob == 0
            # Marginalize out NAN elements. Build the mask on the same device
            # as ``ok``; a bare torch.tensor(pattern) lives on CPU and would
            # raise a device mismatch for CUDA inputs.
            col_mask = torch.tensor(pattern, dtype=torch.bool, device=ok.device)
            row_mask = (ok == col_mask).all(-1)
            ok_value = value[row_mask][:, col_mask]
            ok_loc = loc[row_mask][:, col_mask]
            ok_cov = cov[row_mask][:, col_mask][:, :, col_mask]
            marginal = MultivariateNormal(ok_loc, ok_cov, validate_args=False)
            result[row_mask] = marginal.log_prob(ok_value)

        # Unflatten back to the broadcasted batch shape.
        return result.reshape(result_shape)