Source code for pyro.infer.rws

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import math

import torch

import pyro
import pyro.poutine as poutine
from pyro.infer.elbo import ELBO
from pyro.infer.enum import get_importance_trace
from pyro.infer.util import is_validation_enabled
from pyro.poutine.util import prune_subsample_sites
from pyro.util import check_if_enumerated, check_model_guide_match, warn_if_nan


class ReweightedWakeSleep(ELBO):
    r"""
    An implementation of Reweighted Wake Sleep following reference [1].

    .. note:: Sampling and log_prob evaluation asymptotic complexity:

        1) Using wake-theta and/or wake-phi
            O(`num_particles`) samples from guide,
            O(`num_particles`) `log_prob` evaluations of model and guide

        2) Using sleep-phi
            O(`num_sleep_particles`) samples from model,
            O(`num_sleep_particles`) `log_prob` evaluations of guide

        if 1) and 2) are combined,
            O(`num_particles`) samples from the guide,
            O(`num_sleep_particles`) from the model,
            O(`num_particles` + `num_sleep_particles`) `log_prob` evaluations
            of the guide, and O(`num_particles`) evaluations of the model

    .. note:: This is particularly useful for models with stochastic branching,
        as described in [2].

    .. note:: This returns _two_ losses, one each for (a) the model parameters
        (`theta`), computed using the `iwae` objective, and (b) the guide
        parameters (`phi`), computed using (a combination of) the `csis`
        objective and a self-normalized importance-sampled version of the
        `csis` objective.

    .. note:: In order to enable computing the sleep-phi terms, the guide program
        must have its observations explicitly passed in through the keyword
        argument `observations`. Where the value of the observations is unknown
        during definition, such as for amortized variational inference, it may
        be given a default argument as `observations=None`, and the correct
        value supplied during learning through `svi.step(observations=...)`.
        (A usage sketch follows this listing.)

    .. warning:: Mini-batch training is not supported yet.

    :param int num_particles: The number of particles/samples used to form the
        objective (gradient) estimator. Default is 2.
    :param insomnia: The scaling between the wake-phi and sleep-phi terms.
        Default is 1.0 [wake-phi].
    :param bool model_has_params: Indicate if the model has learnable params.
        Useful in avoiding extra computation when running in pure sleep mode
        [csis]. Default is True.
    :param int num_sleep_particles: The number of particles used to form the
        sleep-phi estimator. Matches `num_particles` by default.
    :param bool vectorize_particles: Whether the traces should be vectorized
        across `num_particles`. Default is True.
    :param int max_plate_nesting: Bound on max number of nested
        :func:`pyro.plate` contexts. Default is infinity.
    :param bool strict_enumeration_warning: Whether to warn about possible
        misuse of enumeration, i.e. that
        :class:`~pyro.infer.traceenum_elbo.TraceEnum_ELBO` is used iff there
        are enumerated sample sites.

    References:

    [1] `Reweighted Wake-Sleep`,
        Jörg Bornschein, Yoshua Bengio

    [2] `Revisiting Reweighted Wake-Sleep for Models with Stochastic Control Flow`,
        Tuan Anh Le, Adam R. Kosiorek, N. Siddharth, Yee Whye Teh, Frank Wood
    """

    def __init__(
        self,
        num_particles=2,
        insomnia=1.0,
        model_has_params=True,
        num_sleep_particles=None,
        vectorize_particles=True,
        max_plate_nesting=float("inf"),
        strict_enumeration_warning=True,
    ):
        # force K > 1, otherwise self-normalized importance sampling (SNIS)
        # is not possible
        assert (
            num_particles > 1
        ), "Reweighted Wake Sleep needs to be run with more than one particle"

        super().__init__(
            num_particles=num_particles,
            max_plate_nesting=max_plate_nesting,
            vectorize_particles=vectorize_particles,
            strict_enumeration_warning=strict_enumeration_warning,
        )
        self.insomnia = insomnia
        self.model_has_params = model_has_params
        self.num_sleep_particles = (
            num_particles if num_sleep_particles is None else num_sleep_particles
        )

        assert insomnia >= 0 and insomnia <= 1, "insomnia should be in [0, 1]"

    def _get_trace(self, model, guide, args, kwargs):
        """
        Returns a single trace from the guide, and the model that is run
        against it.
        """
        model_trace, guide_trace = get_importance_trace(
            "flat", self.max_plate_nesting, model, guide, args, kwargs, detach=True
        )
        if is_validation_enabled():
            check_if_enumerated(guide_trace)
        return model_trace, guide_trace

    def _loss(self, model, guide, args, kwargs):
        """
        :returns: returns model loss and guide loss
        :rtype: torch.Tensor, torch.Tensor

        Computes the re-weighted wake-sleep estimators for the model (wake-theta)
        and the guide (insomnia * wake-phi + (1 - insomnia) * sleep-phi), over
        the specified number of particles. Backpropagation is performed by the
        caller (see :meth:`loss_and_grads`), not here.
        """
        # placeholder value, returned only when the wake-theta computation is
        # skipped (model_has_params=False and insomnia == 0)
        wake_theta_loss = torch.tensor(100.0)
        if self.model_has_params or self.insomnia > 0.0:
            # compute quantities for wake theta and wake phi
            log_joints = []
            log_qs = []

            for model_trace, guide_trace in self._get_traces(
                model, guide, args, kwargs
            ):
                log_joint = 0.0
                log_q = 0.0

                for _, site in model_trace.nodes.items():
                    if site["type"] == "sample":
                        if self.vectorize_particles:
                            log_p_site = (
                                site["log_prob"].reshape(self.num_particles, -1).sum(-1)
                            )
                        else:
                            log_p_site = site["log_prob_sum"]
                        log_joint = log_joint + log_p_site

                for _, site in guide_trace.nodes.items():
                    if site["type"] == "sample":
                        if self.vectorize_particles:
                            log_q_site = (
                                site["log_prob"].reshape(self.num_particles, -1).sum(-1)
                            )
                        else:
                            log_q_site = site["log_prob_sum"]
                        log_q = log_q + log_q_site

                log_joints.append(log_joint)
                log_qs.append(log_q)

            log_joints = (
                log_joints[0] if self.vectorize_particles else torch.stack(log_joints)
            )
            log_qs = log_qs[0] if self.vectorize_particles else torch.stack(log_qs)
            log_weights = log_joints - log_qs.detach()

            # compute wake theta loss
            log_sum_weight = torch.logsumexp(log_weights, dim=0)
            wake_theta_loss = -(log_sum_weight - math.log(self.num_particles)).sum()
            warn_if_nan(wake_theta_loss, "wake theta loss")

        if self.insomnia > 0:
            # compute wake phi loss
            normalised_weights = (log_weights - log_sum_weight).exp().detach()
            wake_phi_loss = -(normalised_weights * log_qs).sum()
            warn_if_nan(wake_phi_loss, "wake phi loss")

        if self.insomnia < 1:
            # compute sleep phi loss
            _model = pyro.poutine.uncondition(model)
            _guide = guide
            _log_q = 0.0

            if self.vectorize_particles:
                if self.max_plate_nesting == float("inf"):
                    self._guess_max_plate_nesting(_model, _guide, args, kwargs)
                _model = self._vectorized_num_sleep_particles(_model)
                _guide = self._vectorized_num_sleep_particles(guide)

            for _ in range(
                1 if self.vectorize_particles else self.num_sleep_particles
            ):
                _model_trace = poutine.trace(_model).get_trace(*args, **kwargs)
                _model_trace.detach_()
                _guide_trace = self._get_matched_trace(
                    _model_trace, _guide, args, kwargs
                )
                _log_q += _guide_trace.log_prob_sum()

            sleep_phi_loss = -_log_q / self.num_sleep_particles
            warn_if_nan(sleep_phi_loss, "sleep phi loss")

        # compute phi loss
        phi_loss = (
            sleep_phi_loss
            if self.insomnia == 0
            else wake_phi_loss
            if self.insomnia == 1
            else self.insomnia * wake_phi_loss
            + (1.0 - self.insomnia) * sleep_phi_loss
        )

        return wake_theta_loss, phi_loss
    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: returns model loss and guide loss
        :rtype: float, float

        Computes the re-weighted wake-sleep estimators for the model (wake-theta)
        and the guide (insomnia * wake-phi + (1 - insomnia) * sleep-phi),
        without tracking gradients.
        """
        with torch.no_grad():
            wake_theta_loss, phi_loss = self._loss(model, guide, args, kwargs)

        return wake_theta_loss, phi_loss
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: returns model loss and guide loss
        :rtype: float, float

        Computes the RWS estimators for the model (wake-theta) and the guide
        (wake-phi). Performs backward as appropriate on both, using
        `num_particles` many samples/particles.
        """
        wake_theta_loss, phi_loss = self._loss(model, guide, args, kwargs)
        # convenience addition to ensure easier gradients without requiring
        # `retain_graph=True`
        (wake_theta_loss + phi_loss).backward()

        return wake_theta_loss.detach().item(), phi_loss.detach().item()
    def _vectorized_num_sleep_particles(self, fn):
        """
        Copy of `_vectorized_num_particles` that uses `num_sleep_particles`.
        """

        def wrapped_fn(*args, **kwargs):
            if self.num_sleep_particles == 1:
                return fn(*args, **kwargs)
            with pyro.plate(
                "num_sleep_particles_vectorized",
                self.num_sleep_particles,
                dim=-self.max_plate_nesting,
            ):
                return fn(*args, **kwargs)

        return wrapped_fn

    @staticmethod
    def _get_matched_trace(model_trace, guide, args, kwargs):
        # condition the guide on the values sampled from the (unconditioned)
        # model, passing them in through the `observations` keyword argument
        kwargs["observations"] = {}
        for node in model_trace.stochastic_nodes + model_trace.observation_nodes:
            if "was_observed" in model_trace.nodes[node]["infer"]:
                model_trace.nodes[node]["is_observed"] = True
                kwargs["observations"][node] = model_trace.nodes[node]["value"]

        guide_trace = poutine.trace(poutine.replay(guide, model_trace)).get_trace(
            *args, **kwargs
        )

        check_model_guide_match(model_trace, guide_trace)
        guide_trace = prune_subsample_sites(guide_trace)

        return guide_trace
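
To illustrate the `observations` contract described in the class docstring, here is a minimal usage sketch. It is not part of this module: the model, guide, parameter `loc`, and data are invented for the example; only `ReweightedWakeSleep`, `SVI`, `pyro.optim.Adam`, and the distribution calls are Pyro APIs.

# Hypothetical usage sketch -- not part of pyro.infer.rws.
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, ReweightedWakeSleep
from pyro.optim import Adam


def model(observations=None):
    # p(z) p(x | z); the observed site reads its value out of `observations`
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    obs_x = observations["x"] if observations is not None else None
    pyro.sample("x", dist.Normal(z, 1.0), obs=obs_x)


def guide(observations=None):
    # q(z): deliberately crude (it ignores `observations`), but the keyword
    # argument must exist so the sleep-phi term can pass in model samples
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("z", dist.Normal(loc, 1.0))


rws = ReweightedWakeSleep(num_particles=8, insomnia=0.5, vectorize_particles=False)
svi = SVI(model, guide, Adam({"lr": 1e-2}), loss=rws)
for _ in range(100):
    wake_theta_loss, phi_loss = svi.step(observations={"x": torch.tensor(1.5)})

Note that `svi.step` here returns the pair `(wake_theta_loss, phi_loss)` produced by `loss_and_grads`, rather than the single scalar loss returned for ordinary ELBOs.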