# Source code listing (the module name is missing from the extracted page header).

from __future__ import absolute_import, division, print_function

from collections import OrderedDict

import torch.nn as nn
from torch.distributions import biject_to, constraints, transform_to
from torch.nn import Parameter

import pyro
import pyro.distributions as dist
from pyro.contrib import autoname
from pyro.distributions.util import eye_like

def _get_independent_support(dist_instance):
    # XXX Should we treat the case dist_instance is Independent(Independent(Normal))?
    if isinstance(dist_instance, dist.Independent):

class Parameterized(nn.Module):
    """
    A wrapper of :class:`torch.nn.Module` whose parameters can be given
    constraints and priors.

    Under the hood, we move parameters to a buffer store and create "root"
    parameters which are used to generate that parameter's value. For example,
    if we set a constraint to a parameter, an "unconstrained" parameter will be
    created, and the constrained value will be transformed from that
    "unconstrained" parameter.

    By default, when we set a prior to a parameter, an auto Delta guide will be
    created. We can use the method :meth:`autoguide` to setup other auto guides.
    To fix a parameter to a specific value, it is enough to turn off its "root"
    parameters' ``requires_grad`` flags.

    Example::

        >>> class Linear(Parameterized):
        ...     def __init__(self, a, b):
        ...         super(Linear, self).__init__()
        ...         self.a = Parameter(a)
        ...         self.b = Parameter(b)
        ...
        ...     def forward(self, x):
        ...         return self.a * x + self.b
        ...
        >>> linear = Linear(torch.tensor(1.), torch.tensor(0.))
        >>> linear.set_constraint("a", constraints.positive)
        >>> linear.set_prior("b", dist.Normal(0, 1))
        >>> linear.autoguide("b", dist.Normal)
        >>> assert "a_unconstrained" in dict(linear.named_parameters())
        >>> assert "b_loc" in dict(linear.named_parameters())
        >>> assert "b_scale_unconstrained" in dict(linear.named_parameters())
        >>> assert "a" in dict(linear.named_buffers())
        >>> assert "b" in dict(linear.named_buffers())
        >>> assert "b_scale" in dict(linear.named_buffers())

    Note that by default, data of a parameter is a float :class:`torch.Tensor`
    (unless we use :func:`torch.set_default_tensor_type` to change default
    tensor type). To cast these parameters to a correct data type or GPU device,
    we can call methods such as :meth:`~torch.nn.Module.double` or
    :meth:`~torch.nn.Module.cuda`. See :class:`torch.nn.Module` for more
    information.
    """

    def __init__(self):
        super(Parameterized, self).__init__()
        # name -> constraint, for parameters backed by a "<name>_unconstrained" parameter
        self._constraints = OrderedDict()
        # name -> prior distribution, for parameters lifted to random samples
        self._priors = OrderedDict()
        # name -> (distribution constructor, set of guide argument names)
        self._guides = OrderedDict()
        # None until a mode is assigned (see :meth:`set_mode`)
        self._mode = None

    def set_constraint(self, name, constraint):
        """
        Sets the constraint of an existing parameter.

        :param str name: Name of the parameter.
        :param ~constraints.Constraint constraint: A PyTorch constraint. See
            :mod:`torch.distributions.constraints` for a list of constraints.
        :raises ValueError: If the parameter already has a prior, or if there
            is no parameter or buffer named ``name``.
        """
        if constraint in [constraints.real, constraints.real_vector]:
            if name in self._constraints:  # delete previous constraints
                self._constraints.pop(name, None)
                self._parameters.pop("{}_unconstrained".format(name))

                if name not in self._priors:
                    # no prior -> no guide
                    # so we can move param back from buffer
                    p = Parameter(self._buffers.pop(name).detach())
                    self.register_parameter(name, p)
            return

        if name in self._priors:
            raise ValueError("Parameter {} already has a prior. Can not set a constraint for it."
                             .format(name))

        if name in self._parameters:
            p = self._parameters.pop(name)
        elif name in self._buffers:
            p = self._buffers[name]
        else:
            raise ValueError("There is no parameter with name: {}".format(name))

        p_unconstrained = Parameter(transform_to(constraint).inv(p).detach())
        self.register_parameter("{}_unconstrained".format(name), p_unconstrained)
        # due to precision issue, we might get f(f^-1(x)) != x
        # so it is necessary to transform back
        p = transform_to(constraint)(p_unconstrained)
        self.register_buffer(name, p.detach())
        self._constraints[name] = constraint

    def set_prior(self, name, prior):
        """
        Sets the prior of an existing parameter.

        Any constraint previously set on the parameter is removed, and an auto
        Delta guide is created for it.

        :param str name: Name of the parameter.
        :param ~pyro.distributions.distribution.Distribution prior: A Pyro prior
            distribution.
        :raises ValueError: If there is no parameter or buffer named ``name``.
        """
        if name in self._parameters:
            # move param to _buffers
            p = self._parameters.pop(name)
            self.register_buffer(name, p)
        elif name not in self._buffers:
            raise ValueError("There is no parameter with name: {}".format(name))

        self._priors[name] = prior
        # remove the constraint and its unconstrained parameter
        self.set_constraint(name, constraints.real)
        self.autoguide(name, dist.Delta)

    def autoguide(self, name, dist_constructor):
        """
        Sets an autoguide for an existing parameter with name ``name`` (mimic
        the behavior of module :mod:`pyro.contrib.autoguide`).

        .. note:: `dist_constructor` should be one of
            :class:`~pyro.distributions.Delta`,
            :class:`~pyro.distributions.Normal`, and
            :class:`~pyro.distributions.MultivariateNormal`. More distribution
            constructor will be supported in the future if needed.

        :param str name: Name of the parameter.
        :param dist_constructor: A
            :class:`~pyro.distributions.distribution.Distribution` constructor.
        :raises ValueError: If no prior has been set for ``name``.
        :raises NotImplementedError: If ``dist_constructor`` is not one of the
            supported constructors.
        """
        if name not in self._priors:
            raise ValueError("There is no prior for parameter: {}".format(name))

        if dist_constructor not in [dist.Delta, dist.Normal, dist.MultivariateNormal]:
            raise NotImplementedError("Unsupported distribution type: {}"
                                      .format(dist_constructor))

        if name in self._guides:
            # delete previous guide's parameters
            dist_args = self._guides[name][1]
            for arg in dist_args:
                arg_name = "{}_{}".format(name, arg)
                if arg_name in self._constraints:
                    # delete its unconstrained parameter
                    self.set_constraint(arg_name, constraints.real)
                delattr(self, arg_name)

        # TODO: create a new argument `autoguide_args` to store other args for other
        # constructors. For example, in LowRankMVN, we need argument `rank`.
        p = self._buffers[name]
        if dist_constructor is dist.Delta:
            p_map = Parameter(p.detach())
            self.register_parameter("{}_map".format(name), p_map)
            self.set_constraint("{}_map".format(name),
                                _get_independent_support(self._priors[name]))
            dist_args = {"map"}
        elif dist_constructor is dist.Normal:
            loc = Parameter(biject_to(self._priors[name].support).inv(p).detach())
            scale = Parameter(loc.new_ones(loc.shape))
            self.register_parameter("{}_loc".format(name), loc)
            self.register_parameter("{}_scale".format(name), scale)
            dist_args = {"loc", "scale"}
        elif dist_constructor is dist.MultivariateNormal:
            loc = Parameter(biject_to(self._priors[name].support).inv(p).detach())
            identity = eye_like(loc, loc.size(-1))
            scale_tril = Parameter(identity.repeat(loc.shape[:-1] + (1, 1)))
            self.register_parameter("{}_loc".format(name), loc)
            self.register_parameter("{}_scale_tril".format(name), scale_tril)
            dist_args = {"loc", "scale_tril"}
        else:
            raise NotImplementedError

        if dist_constructor is not dist.Delta:
            # each arg has a constraint, so we set constraints for them
            for arg in dist_args:
                self.set_constraint("{}_{}".format(name, arg),
                                    dist_constructor.arg_constraints[arg])
        self._guides[name] = (dist_constructor, dist_args)

    def set_mode(self, mode):
        """
        Sets ``mode`` of this object to be able to use its parameters in
        stochastic functions. If ``mode="model"``, a parameter will get its
        value from its prior. If ``mode="guide"``, the value will be drawn from
        its guide.

        .. note:: This method automatically sets ``mode`` for submodules which
            belong to :class:`Parameterized` class.

        :param str mode: Either "model" or "guide".
        """
        with autoname.name_count():
            for module in self.modules():
                if isinstance(module, Parameterized):
                    module.mode = mode

    @property
    def mode(self):
        """
        The current mode of this object (``None`` until one is assigned).

        Assigning to this property stores the mode and refreshes the buffers of
        all constrained parameters and of all parameters that have a prior.
        """
        return self._mode

    @mode.setter
    def mode(self, mode):
        self._mode = mode
        # We should get buffer values for constrained params first
        # otherwise, autoguide will use the old buffer for `scale` or `scale_tril`
        for name in self._constraints:
            if name not in self._priors:
                self._register_param(name)
        for name in self._priors:
            self._register_param(name)

    def _sample_from_guide(self, name):
        """
        Draws the value of the parameter ``name`` from its registered guide.

        :param str name: Name of the parameter.
        :returns: The value returned by the ``pyro.sample`` site named ``name``.
        """
        dist_constructor, dist_args = self._guides[name]

        if dist_constructor is dist.Delta:
            p_map = getattr(self, "{}_map".format(name))
            return pyro.sample(name, dist.Delta(p_map, event_dim=p_map.dim()))

        # create guide
        dist_args = {arg: getattr(self, "{}_{}".format(name, arg)) for arg in dist_args}
        guide = dist_constructor(**dist_args)

        # no need to do transforms when support is real (for mean field ELBO)
        if _get_independent_support(self._priors[name]) is constraints.real:
            return pyro.sample(name, guide.to_event())

        # otherwise, we do inference in unconstrained space and transform the value
        # back to original space
        # TODO: move this logic to contrib.autoguide or somewhere else
        unconstrained_value = pyro.sample("{}_latent".format(name), guide.to_event(),
                                          infer={"is_auxiliary": True})
        transform = biject_to(self._priors[name].support)
        value = transform(unconstrained_value)
        log_density = transform.inv.log_abs_det_jacobian(value, unconstrained_value)
        return pyro.sample(name, dist.Delta(value, log_density.sum(), event_dim=value.dim()))

    def _register_param(self, name):
        """
        In "model" mode, lifts the parameter with name ``name`` to a random
        sample using a predefined prior (from :meth:`set_prior` method). In
        "guide" mode, we use the guide generated from :meth:`autoguide`.

        For a parameter that only has a constraint, the value is recomputed
        from its "unconstrained" parameter. In every case the resulting value
        is stored as the buffer ``name``.

        :param str name: Name of the parameter.
        """
        if name in self._priors:
            with autoname.scope(prefix=self._get_name()):
                if self.mode == "model":
                    p = pyro.sample(name, self._priors[name])
                else:
                    p = self._sample_from_guide(name)
        elif name in self._constraints:
            p_unconstrained = self._parameters["{}_unconstrained".format(name)]
            p = transform_to(self._constraints[name])(p_unconstrained)
        self.register_buffer(name, p)