Source code for pyro.infer.trace_tail_adaptive_elbo

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import warnings

import torch

from pyro.infer.trace_elbo import Trace_ELBO
from pyro.infer.util import check_fully_reparametrized, is_validation_enabled


class TraceTailAdaptive_ELBO(Trace_ELBO):
    """
    Interface for Stochastic Variational Inference with an adaptive f-divergence
    as described in ref. [1]. Users should specify `num_particles` > 1 and
    `vectorize_particles==True`. The argument `tail_adaptive_beta` can be
    specified to modify how the adaptive f-divergence is constructed. See
    reference for details.

    Note that this interface does not support computing the variational objective
    itself; rather it only supports computing gradients of the variational
    objective. Consequently, one might want to use another SVI interface
    (e.g. `RenyiELBO`) in order to monitor convergence.

    Note that this interface only supports models in which all the latent
    variables are fully reparameterized. It also does not support data
    subsampling.

    References
    [1] "Variational Inference with Tail-adaptive f-Divergence",
        Dilin Wang, Hao Liu, Qiang Liu, NeurIPS 2018
        https://papers.nips.cc/paper/7816-variational-inference-with-tail-adaptive-f-divergence
    """
    def loss(self, model, guide, *args, **kwargs):
        """
        It is not necessary to estimate the tail-adaptive f-divergence itself in
        order to compute the corresponding gradients. Consequently the loss method
        is left unimplemented.
        """
        raise NotImplementedError(
            "Loss method for TraceTailAdaptive_ELBO not implemented"
        )
    def _differentiable_loss_particle(self, model_trace, guide_trace):
        if not self.vectorize_particles:
            raise NotImplementedError(
                "TraceTailAdaptive_ELBO only implemented for vectorize_particles==True"
            )
        if self.num_particles == 1:
            warnings.warn(
                "For num_particles==1 TraceTailAdaptive_ELBO uses the same loss function as Trace_ELBO. "
                + "Increase num_particles to get an adaptive f-divergence."
            )

        # accumulate the model and guide log-densities site by site,
        # keeping one value per particle
        log_p, log_q = 0, 0

        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                site_log_p = site["log_prob"].reshape(self.num_particles, -1).sum(-1)
                log_p = log_p + site_log_p

        for name, site in guide_trace.nodes.items():
            if site["type"] == "sample":
                site_log_q = site["log_prob"].reshape(self.num_particles, -1).sum(-1)
                log_q = log_q + site_log_q
                if is_validation_enabled():
                    check_fully_reparametrized(site)

        # rank the particles according to p/q
        log_pq = log_p - log_q
        rank = torch.argsort(log_pq, descending=False)
        rank = torch.index_select(
            torch.arange(self.num_particles, device=log_pq.device) + 1, -1, rank
        ).type_as(log_pq)

        # compute the particle-specific weights used to construct the surrogate loss
        gamma = torch.pow(rank, self.tail_adaptive_beta).detach()
        surrogate_loss = -(log_pq * gamma).sum() / gamma.sum()

        # we do not compute the loss, so return `inf`
        return float("inf"), surrogate_loss
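
The following is a minimal usage sketch, not part of the module above. It assumes a toy, fully reparameterized Normal model and guide; the names `model`, `guide`, and `data` are illustrative only. It shows the settings the docstring calls for (`num_particles > 1`, `vectorize_particles=True`) and, since `loss` is unimplemented, uses a separate `RenyiELBO` objective to monitor convergence.

# Usage sketch (illustrative; not part of the module above).
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, RenyiELBO, TraceTailAdaptive_ELBO
from pyro.optim import Adam


def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)


def guide(data):
    q_loc = pyro.param("q_loc", torch.tensor(0.0))
    q_scale = pyro.param(
        "q_scale", torch.tensor(1.0), constraint=dist.constraints.positive
    )
    pyro.sample("loc", dist.Normal(q_loc, q_scale))


data = torch.randn(100) + 3.0

elbo = TraceTailAdaptive_ELBO(
    num_particles=32,          # must be > 1 for the adaptive weights to matter
    vectorize_particles=True,  # required by this implementation
    max_plate_nesting=1,
    tail_adaptive_beta=-1.0,   # see ref. [1] for how beta shapes the f-divergence
)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=elbo)

# separate objective for monitoring only, since TraceTailAdaptive_ELBO.loss
# raises NotImplementedError
monitor = RenyiELBO(
    alpha=0.0, num_particles=32, vectorize_particles=True, max_plate_nesting=1
)

for step in range(1000):
    svi.step(data)  # returns float("inf"); only the gradients are meaningful
    if step % 100 == 0:
        print(step, monitor.loss(model, guide, data))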