Source code for pyro.infer.renyi_elbo

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import math
import warnings

import torch

from pyro.distributions.util import is_identically_zero
from pyro.infer.elbo import ELBO
from pyro.infer.enum import get_importance_trace
from pyro.infer.util import get_dependent_plate_dims, is_validation_enabled, torch_sum
from pyro.util import check_if_enumerated, warn_if_nan


class RenyiELBO(ELBO):
    r"""
    An implementation of Renyi's :math:`\alpha`-divergence
    variational inference following reference [1].

    In order for the objective to be a strict lower bound, we require
    :math:`\alpha \ge 0`. Note, however, that according to reference [1], depending
    on the dataset :math:`\alpha < 0` might give better results. In the special case
    :math:`\alpha = 0`, the objective function is that of the importance weighted
    autoencoder derived in reference [2].

    .. note:: Setting :math:`\alpha < 1` gives a better bound than the usual ELBO.
        For :math:`\alpha = 1`, it is better to use the
        :class:`~pyro.infer.trace_elbo.Trace_ELBO` class because it helps reduce
        the variance of gradient estimates.

    :param float alpha: The order of :math:`\alpha`-divergence.
        Here :math:`\alpha \neq 1`. Default is 0.
    :param num_particles: The number of particles/samples used to form the
        objective (gradient) estimator. Default is 2.
    :param int max_plate_nesting: Bound on max number of nested
        :func:`pyro.plate` contexts. Default is infinity.
    :param bool strict_enumeration_warning: Whether to warn about possible
        misuse of enumeration, i.e. that
        :class:`~pyro.infer.traceenum_elbo.TraceEnum_ELBO` is used iff there
        are enumerated sample sites.

    References:

    [1] `Renyi Divergence Variational Inference`,
        Yingzhen Li, Richard E. Turner

    [2] `Importance Weighted Autoencoders`,
        Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
    """

    def __init__(
        self,
        alpha=0,
        num_particles=2,
        max_plate_nesting=float("inf"),
        max_iarange_nesting=None,  # DEPRECATED
        vectorize_particles=False,
        strict_enumeration_warning=True,
    ):
        if max_iarange_nesting is not None:
            warnings.warn(
                "max_iarange_nesting is deprecated; use max_plate_nesting instead",
                DeprecationWarning,
            )
            max_plate_nesting = max_iarange_nesting

        if alpha == 1:
            raise ValueError(
                "The order alpha should not be equal to 1. Please use the "
                "Trace_ELBO class for the case alpha = 1."
            )
        self.alpha = alpha
        super().__init__(
            num_particles=num_particles,
            max_plate_nesting=max_plate_nesting,
            vectorize_particles=vectorize_particles,
            strict_enumeration_warning=strict_enumeration_warning,
        )

    def _get_trace(self, model, guide, args, kwargs):
        """
        Returns a single trace from the guide, and the model that is run
        against it.
        """
        model_trace, guide_trace = get_importance_trace(
            "flat", self.max_plate_nesting, model, guide, args, kwargs
        )
        if is_validation_enabled():
            check_if_enumerated(guide_trace)
        return model_trace, guide_trace
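    # Background note (a sketch of the objective, following reference [1] in the
    # class docstring): for a model p(x, z) and a guide q(z), the variational
    # Renyi bound of order alpha is
    #
    #     L_alpha = 1 / (1 - alpha) * log E_q[ (p(x, z) / q(z)) ** (1 - alpha) ],
    #
    # which equals log p(x) - D_alpha(q(z) || p(z | x)). It recovers the usual
    # ELBO as alpha -> 1 and, when estimated with multiple samples at alpha = 0,
    # the importance weighted autoencoder (IWAE) objective of reference [2].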
    @torch.no_grad()
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: an estimate of the ELBO
        :rtype: float

        Evaluates the ELBO with an estimator that uses num_particles many
        samples/particles.
        """
        elbo_particles = []
        is_vectorized = self.vectorize_particles and self.num_particles > 1

        # grab a vectorized trace from the generator
        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            elbo_particle = 0.0
            sum_dims = get_dependent_plate_dims(model_trace.nodes.values())

            # compute elbo
            for name, site in model_trace.nodes.items():
                if site["type"] == "sample":
                    log_prob_sum = torch_sum(site["log_prob"], sum_dims)
                    elbo_particle = elbo_particle + log_prob_sum

            for name, site in guide_trace.nodes.items():
                if site["type"] == "sample":
                    log_prob, score_function_term, entropy_term = site["score_parts"]
                    log_prob_sum = torch_sum(site["log_prob"], sum_dims)
                    elbo_particle = elbo_particle - log_prob_sum

            elbo_particles.append(elbo_particle)

        if is_vectorized:
            elbo_particles = elbo_particles[0]
        else:
            elbo_particles = torch.stack(elbo_particles)

        log_weights = (1.0 - self.alpha) * elbo_particles
        log_mean_weight = torch.logsumexp(log_weights, dim=0) - math.log(
            self.num_particles
        )
        elbo = log_mean_weight.sum().item() / (1.0 - self.alpha)

        loss = -elbo
        warn_if_nan(loss, "loss")
        return loss
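    # Note on loss() above: with K = num_particles and per-particle log weights
    # log w_k = log p(x, z_k) - log q(z_k), the returned value is the Monte Carlo
    # estimate
    #
    #     -1 / (1 - alpha) * log( (1 / K) * sum_k w_k ** (1 - alpha) ),
    #
    # evaluated stably via logsumexp over the particle dimension.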
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: an estimate of the ELBO
        :rtype: float

        Computes the ELBO as well as the surrogate ELBO that is used to form
        the gradient estimator. Performs backward on the latter. num_particles
        many samples are used to form the estimators.
        """
        elbo_particles = []
        surrogate_elbo_particles = []
        is_vectorized = self.vectorize_particles and self.num_particles > 1
        tensor_holder = None

        # grab a vectorized trace from the generator
        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            elbo_particle = 0
            surrogate_elbo_particle = 0
            sum_dims = get_dependent_plate_dims(model_trace.nodes.values())

            # compute elbo and surrogate elbo
            for name, site in model_trace.nodes.items():
                if site["type"] == "sample":
                    log_prob_sum = torch_sum(site["log_prob"], sum_dims)
                    elbo_particle = elbo_particle + log_prob_sum.detach()
                    surrogate_elbo_particle = surrogate_elbo_particle + log_prob_sum

            for name, site in guide_trace.nodes.items():
                if site["type"] == "sample":
                    log_prob, score_function_term, entropy_term = site["score_parts"]
                    log_prob_sum = torch_sum(site["log_prob"], sum_dims)

                    elbo_particle = elbo_particle - log_prob_sum.detach()

                    if not is_identically_zero(entropy_term):
                        surrogate_elbo_particle = surrogate_elbo_particle - log_prob_sum

                        if not is_identically_zero(score_function_term):
                            # link to the issue: https://github.com/pyro-ppl/pyro/issues/1222
                            raise NotImplementedError

                    if not is_identically_zero(score_function_term):
                        surrogate_elbo_particle = (
                            surrogate_elbo_particle
                            + (self.alpha / (1.0 - self.alpha)) * log_prob_sum
                        )

            if is_identically_zero(elbo_particle):
                if tensor_holder is not None:
                    elbo_particle = torch.zeros_like(tensor_holder)
                    surrogate_elbo_particle = torch.zeros_like(tensor_holder)
            else:  # elbo_particle is not identically zero
                if tensor_holder is None:
                    tensor_holder = torch.zeros_like(elbo_particle)
                    # change types of previous `elbo_particle`s
                    for i in range(len(elbo_particles)):
                        elbo_particles[i] = torch.zeros_like(tensor_holder)
                        surrogate_elbo_particles[i] = torch.zeros_like(tensor_holder)

            elbo_particles.append(elbo_particle)
            surrogate_elbo_particles.append(surrogate_elbo_particle)

        if tensor_holder is None:
            return 0.0

        if is_vectorized:
            elbo_particles = elbo_particles[0]
            surrogate_elbo_particles = surrogate_elbo_particles[0]
        else:
            elbo_particles = torch.stack(elbo_particles)
            surrogate_elbo_particles = torch.stack(surrogate_elbo_particles)

        log_weights = (1.0 - self.alpha) * elbo_particles
        log_mean_weight = torch.logsumexp(log_weights, dim=0, keepdim=True) - math.log(
            self.num_particles
        )
        elbo = log_mean_weight.sum().item() / (1.0 - self.alpha)

        # collect parameters to train from model and guide
        trainable_params = any(
            site["type"] == "param"
            for trace in (model_trace, guide_trace)
            for site in trace.nodes.values()
        )

        if trainable_params and getattr(
            surrogate_elbo_particles, "requires_grad", False
        ):
            normalized_weights = (log_weights - log_mean_weight).exp()
            surrogate_elbo = (
                normalized_weights * surrogate_elbo_particles
            ).sum() / self.num_particles
            surrogate_loss = -surrogate_elbo
            surrogate_loss.backward()

        loss = -elbo
        warn_if_nan(loss, "loss")
        return loss
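

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module): a minimal
# example of plugging the RenyiELBO loss defined above into SVI. The toy
# model, guide, and data are hypothetical; SVI, Adam, pyro.sample, pyro.param,
# and pyro.plate are standard Pyro APIs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import pyro
    import pyro.distributions as dist
    from pyro.distributions import constraints
    from pyro.infer import SVI
    from pyro.optim import Adam

    data = torch.randn(10) + 3.0  # toy observations centered around 3.0

    def model(data):
        # latent location with a broad prior
        loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    def guide(data):
        loc_q = pyro.param("loc_q", torch.tensor(0.0))
        scale_q = pyro.param(
            "scale_q", torch.tensor(1.0), constraint=constraints.positive
        )
        pyro.sample("loc", dist.Normal(loc_q, scale_q))

    # alpha=0 gives the multi-sample importance weighted (IWAE-style) objective
    elbo = RenyiELBO(alpha=0, num_particles=4)
    svi = SVI(model, guide, Adam({"lr": 0.01}), loss=elbo)
    for step in range(100):
        svi.step(data)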