# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
"""
This module implements the Causal Effect Variational Autoencoder [1], which
demonstrates a number of innovations including:
- a generative model for causal effect inference with hidden confounders;
- a model and guide with twin neural nets to allow imbalanced treatment; and
- a custom training loss that includes both ELBO terms and extra terms needed
to train the guide to be able to answer counterfactual queries.
The main interface is the :class:`CEVAE` class, but users may customize by
using components :class:`Model`, :class:`Guide`,
:class:`TraceCausalEffect_ELBO` and utilities.
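A minimal end-to-end sketch (``x``, ``t``, ``y`` are assumed to be observed
training tensors with matching batch dimension)::

    cevae = CEVAE(feature_dim=x.size(-1), outcome_dist="bernoulli")
    cevae.fit(x, t, y, num_epochs=100)
    ite = cevae.ite(x_new)   # individual treatment effects for held-out data
    ate = ite.mean()         # average treatment effect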
**References**
[1] C. Louizos, U. Shalit, J. Mooij, D. Sontag, R. Zemel, M. Welling (2017).
| Causal Effect Inference with Deep Latent-Variable Models.
| http://papers.nips.cc/paper/7223-causal-effect-inference-with-deep-latent-variable-models.pdf
| https://github.com/AMLab-Amsterdam/CEVAE
"""
import logging
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.util import torch_item
from pyro.nn import PyroModule
from pyro.optim import ClippedAdam
from pyro.util import torch_isnan
logger = logging.getLogger(__name__)
class FullyConnected(nn.Sequential):
"""
Fully connected multi-layer network with ELU activations.
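For example, ``FullyConnected([3, 16, 2])`` stacks ``Linear(3, 16)``, an ELU,
and ``Linear(16, 2)``, leaving the final layer linear (sizes here are
illustrative)::

    net = FullyConnected([3, 16, 2])
    out = net(torch.randn(5, 3))   # shape (5, 2)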
"""
def __init__(self, sizes, final_activation=None):
layers = []
for in_size, out_size in zip(sizes, sizes[1:]):
layers.append(nn.Linear(in_size, out_size))
layers.append(nn.ELU())
layers.pop(-1)  # drop the trailing ELU; the optional final_activation is appended below
if final_activation is not None:
layers.append(final_activation)
super().__init__(*layers)
def append(self, layer):
assert isinstance(layer, nn.Module)
self.add_module(str(len(self)), layer)
class DistributionNet(nn.Module):
"""
Base class for distribution nets.
"""
@staticmethod
def get_class(dtype):
"""
Get a subclass by a prefix of its name, e.g.::
assert DistributionNet.get_class("bernoulli") is BernoulliNet
"""
for cls in DistributionNet.__subclasses__():
if cls.__name__.lower() == dtype + "net":
return cls
raise ValueError("dtype not supported: {}".format(dtype))
class BernoulliNet(DistributionNet):
"""
:class:`FullyConnected` network outputting a single ``logits`` value.
This is used to represent a conditional probability distribution of a
single Bernoulli random variable conditioned on a ``sizes[0]``-sized real
value, for example::
net = BernoulliNet([3, 4])
z = torch.randn(3)
logits, = net(z)
t = net.make_dist(logits).sample()
"""
def __init__(self, sizes):
assert len(sizes) >= 1
super().__init__()
self.fc = FullyConnected(sizes + [1])
def forward(self, x):
logits = self.fc(x).squeeze(-1).clamp(min=-10, max=10)
return (logits,)
@staticmethod
def make_dist(logits):
return dist.Bernoulli(logits=logits)
class ExponentialNet(DistributionNet):
"""
:class:`FullyConnected` network outputting a constrained ``rate``.
This is used to represent a conditional probability distribution of a
single Exponential random variable conditioned on a ``sizes[0]``-sized real
value, for example::
net = ExponentialNet([3, 4])
x = torch.randn(3)
rate, = net(x)
y = net.make_dist(rate).sample()
"""
def __init__(self, sizes):
assert len(sizes) >= 1
super().__init__()
self.fc = FullyConnected(sizes + [1])
def forward(self, x):
# Predict a positive scale (the mean), then convert to the rate = 1 / scale.
scale = nn.functional.softplus(self.fc(x).squeeze(-1)).clamp(min=1e-3, max=1e6)
rate = scale.reciprocal()
return (rate,)
@staticmethod
def make_dist(rate):
return dist.Exponential(rate)
class LaplaceNet(DistributionNet):
"""
:class:`FullyConnected` network outputting a constrained ``loc,scale``
pair.
This is used to represent a conditional probability distribution of a
single Laplace random variable conditioned on a ``sizes[0]``-sized real
value, for example::
net = LaplaceNet([3, 4])
x = torch.randn(3)
loc, scale = net(x)
y = net.make_dist(loc, scale).sample()
"""
def __init__(self, sizes):
assert len(sizes) >= 1
super().__init__()
self.fc = FullyConnected(sizes + [2])
def forward(self, x):
loc_scale = self.fc(x)
loc = loc_scale[..., 0].clamp(min=-1e6, max=1e6)
scale = nn.functional.softplus(loc_scale[..., 1]).clamp(min=1e-3, max=1e6)
return loc, scale
@staticmethod
def make_dist(loc, scale):
return dist.Laplace(loc, scale)
class NormalNet(DistributionNet):
"""
:class:`FullyConnected` network outputting a constrained ``loc,scale``
pair.
This is used to represent a conditional probability distribution of a
single Normal random variable conditioned on a ``sizes[0]``-sized real
value, for example::
net = NormalNet([3, 4])
x = torch.randn(3)
loc, scale = net(x)
y = net.make_dist(loc, scale).sample()
"""
def __init__(self, sizes):
assert len(sizes) >= 1
super().__init__()
self.fc = FullyConnected(sizes + [2])
def forward(self, x):
loc_scale = self.fc(x)
loc = loc_scale[..., 0].clamp(min=-1e6, max=1e6)
scale = nn.functional.softplus(loc_scale[..., 1]).clamp(min=1e-3, max=1e6)
return loc, scale
@staticmethod
def make_dist(loc, scale):
return dist.Normal(loc, scale)
class StudentTNet(DistributionNet):
"""
:class:`FullyConnected` network outputting a constrained ``df,loc,scale``
triple, with shared ``df > 1``.
This is used to represent a conditional probability distribution of a
single Student's t random variable conditioned on a ``sizes[0]``-sized real
value, for example::
net = StudentTNet([3, 4])
x = torch.randn(3)
df, loc, scale = net(x)
y = net.make_dist(df, loc, scale).sample()
"""
def __init__(self, sizes):
assert len(sizes) >= 1
super().__init__()
self.fc = FullyConnected(sizes + [2])
self.df_unconstrained = nn.Parameter(torch.tensor(0.0))
def forward(self, x):
loc_scale = self.fc(x)
loc = loc_scale[..., 0].clamp(min=-1e6, max=1e6)
scale = nn.functional.softplus(loc_scale[..., 1]).clamp(min=1e-3, max=1e6)
df = nn.functional.softplus(self.df_unconstrained).add(1).expand_as(loc)
return df, loc, scale
@staticmethod
def make_dist(df, loc, scale):
return dist.StudentT(df, loc, scale)
class DiagNormalNet(nn.Module):
"""
:class:`FullyConnected` network outputting a constrained ``loc,scale``
pair.
This is used to represent a conditional probability distribution of a
``sizes[-1]``-sized diagonal Normal random variable conditioned on a
``sizes[0]``-sized real value, for example::
net = DiagNormalNet([3, 4, 5])
z = torch.randn(3)
loc, scale = net(z)
x = dist.Normal(loc, scale).sample()
This is intended for the latent ``z`` distribution and the prewhitened
``x`` features, and conservatively clips ``loc`` and ``scale`` values.
"""
def __init__(self, sizes):
assert len(sizes) >= 2
self.dim = sizes[-1]
super().__init__()
self.fc = FullyConnected(sizes[:-1] + [self.dim * 2])
def forward(self, x):
loc_scale = self.fc(x)
loc = loc_scale[..., : self.dim].clamp(min=-1e2, max=1e2)
scale = (
nn.functional.softplus(loc_scale[..., self.dim :]).add(1e-3).clamp(max=1e2)
)
return loc, scale
class PreWhitener(nn.Module):
"""
Data pre-whitener.
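Shifts and rescales each feature to roughly zero mean and unit variance,
for example (the data below is illustrative)::

    data = torch.randn(100, 5) * 3 + 7
    whiten = PreWhitener(data)
    whitened = whiten(data)   # per-feature mean ~ 0, std ~ 1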
"""
def __init__(self, data):
super().__init__()
with torch.no_grad():
loc = data.mean(0)
scale = data.std(0, unbiased=False)
scale[~(scale > 0)] = 1.0
self.register_buffer("loc", loc)
self.register_buffer("inv_scale", scale.reciprocal())
def forward(self, data):
return (data - self.loc) * self.inv_scale
class Model(PyroModule):
"""
Generative model for a causal model with latent confounder ``z`` and binary
treatment ``t``::
z ~ p(z) # latent confounder
x ~ p(x|z) # partial noisy observation of z
t ~ p(t|z) # treatment, whose application is biased by z
y ~ p(y|t,z) # outcome
Each of these distributions is defined by a neural network. The ``y``
distribution is defined by a disjoint pair of neural networks defining
``p(y|t=0,z)`` and ``p(y|t=1,z)``; this allows highly imbalanced treatment.
:param dict config: A dict specifying ``feature_dim``, ``latent_dim``,
``hidden_dim``, ``num_layers``, and ``outcome_dist``.
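For example (the concrete sizes are illustrative)::

    config = dict(feature_dim=5, latent_dim=20, hidden_dim=200,
                  num_layers=3, outcome_dist="bernoulli")
    model = Model(config)
    x = torch.randn(10, 5)
    y = model(x)   # draws z, t and y given the observed features x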
"""
def __init__(self, config):
self.latent_dim = config["latent_dim"]
super().__init__()
self.x_nn = DiagNormalNet(
[config["latent_dim"]]
+ [config["hidden_dim"]] * config["num_layers"]
+ [config["feature_dim"]]
)
OutcomeNet = DistributionNet.get_class(config["outcome_dist"])
# The y network is split between the two t values.
self.y0_nn = OutcomeNet(
[config["latent_dim"]] + [config["hidden_dim"]] * config["num_layers"]
)
self.y1_nn = OutcomeNet(
[config["latent_dim"]] + [config["hidden_dim"]] * config["num_layers"]
)
self.t_nn = BernoulliNet([config["latent_dim"]])
def forward(self, x, t=None, y=None, size=None):
if size is None:
size = x.size(0)
with pyro.plate("data", size, subsample=x):
z = pyro.sample("z", self.z_dist())
x = pyro.sample("x", self.x_dist(z), obs=x)
t = pyro.sample("t", self.t_dist(z), obs=t)
y = pyro.sample("y", self.y_dist(t, z), obs=y)
return y
def y_mean(self, x, t=None):
with pyro.plate("data", x.size(0)):
z = pyro.sample("z", self.z_dist())
x = pyro.sample("x", self.x_dist(z), obs=x)
t = pyro.sample("t", self.t_dist(z), obs=t)
return self.y_dist(t, z).mean
def z_dist(self):
return dist.Normal(0, 1).expand([self.latent_dim]).to_event(1)
def x_dist(self, z):
loc, scale = self.x_nn(z)
return dist.Normal(loc, scale).to_event(1)
def y_dist(self, t, z):
# Parameters are not shared among t values.
params0 = self.y0_nn(z)
params1 = self.y1_nn(z)
t = t.bool()
params = [torch.where(t, p1, p0) for p0, p1 in zip(params0, params1)]
return self.y0_nn.make_dist(*params)
def t_dist(self, z):
(logits,) = self.t_nn(z)
return dist.Bernoulli(logits=logits)
class Guide(PyroModule):
"""
Inference model for causal effect estimation with latent confounder ``z``
and binary treatment ``t``::
t ~ q(t|x) # treatment
y ~ q(y|t,x) # outcome
z ~ q(z|y,t,x) # latent confounder, an embedding
Each of these distributions is defined by a neural network. The ``y`` and
``z`` distributions are defined by disjoint pairs of neural networks
defining ``p(-|t=0,...)`` and ``p(-|t=1,...)``; this allows highly
imbalanced treatment.
:param dict config: A dict specifying ``feature_dim``, ``latent_dim``,
``hidden_dim``, ``num_layers``, and ``outcome_dist``.
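For example (the concrete sizes are illustrative)::

    config = dict(feature_dim=5, latent_dim=20, hidden_dim=200,
                  num_layers=3, outcome_dist="bernoulli")
    guide = Guide(config)
    x = torch.randn(10, 5)
    guide(x)   # samples t and y, then the latent confounder z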
"""
def __init__(self, config):
self.latent_dim = config["latent_dim"]
OutcomeNet = DistributionNet.get_class(config["outcome_dist"])
super().__init__()
self.t_nn = BernoulliNet([config["feature_dim"]])
# The y and z networks both follow an architecture where the first few
# layers are shared for t in {0,1}, but the final layer is split
# between the two t values.
self.y_nn = FullyConnected(
[config["feature_dim"]]
+ [config["hidden_dim"]] * (config["num_layers"] - 1),
final_activation=nn.ELU(),
)
self.y0_nn = OutcomeNet([config["hidden_dim"]])
self.y1_nn = OutcomeNet([config["hidden_dim"]])
self.z_nn = FullyConnected(
[1 + config["feature_dim"]]
+ [config["hidden_dim"]] * (config["num_layers"] - 1),
final_activation=nn.ELU(),
)
self.z0_nn = DiagNormalNet([config["hidden_dim"], config["latent_dim"]])
self.z1_nn = DiagNormalNet([config["hidden_dim"], config["latent_dim"]])
def forward(self, x, t=None, y=None, size=None):
if size is None:
size = x.size(0)
with pyro.plate("data", size, subsample=x):
# The t and y sites are needed for prediction, and participate in
# the auxiliary CEVAE loss. We mark them auxiliary to indicate they
# do not correspond to latent variables during training.
t = pyro.sample("t", self.t_dist(x), obs=t, infer={"is_auxiliary": True})
y = pyro.sample("y", self.y_dist(t, x), obs=y, infer={"is_auxiliary": True})
# The z site participates only in the usual ELBO loss.
pyro.sample("z", self.z_dist(y, t, x))
def t_dist(self, x):
(logits,) = self.t_nn(x)
return dist.Bernoulli(logits=logits)
def y_dist(self, t, x):
# The first n-1 layers are identical for all t values.
hidden = self.y_nn(x)
# In the final layer params are not shared among t values.
params0 = self.y0_nn(hidden)
params1 = self.y1_nn(hidden)
t = t.bool()
params = [torch.where(t, p1, p0) for p0, p1 in zip(params0, params1)]
return self.y0_nn.make_dist(*params)
def z_dist(self, y, t, x):
# The first n-1 layers are identical for all t values.
y_x = torch.cat([y.unsqueeze(-1), x], dim=-1)
hidden = self.z_nn(y_x)
# In the final layer params are not shared among t values.
params0 = self.z0_nn(hidden)
params1 = self.z1_nn(hidden)
t = t.bool().unsqueeze(-1)
params = [torch.where(t, p1, p0) for p0, p1 in zip(params0, params1)]
return dist.Normal(*params).to_event(1)
class TraceCausalEffect_ELBO(Trace_ELBO):
"""
Loss function for training a :class:`CEVAE`.
From [1], the CEVAE objective (to maximize) is::
-loss = ELBO + log q(t|x) + log q(y|t,x)
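
A minimal training sketch, mirroring :meth:`CEVAE.fit`, where ``model`` and
``guide`` are :class:`Model` and :class:`Guide` instances and ``x``, ``t``,
``y`` are training tensors (the optimizer settings are illustrative)::

    svi = SVI(model, guide, ClippedAdam({"lr": 1e-3}), TraceCausalEffect_ELBO())
    loss = svi.step(x, t, y, size=len(x))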
"""
def _differentiable_loss_particle(self, model_trace, guide_trace):
# Construct -ELBO part.
blocked_names = [
name
for name, site in guide_trace.nodes.items()
if site["type"] == "sample" and site["is_observed"]
]
blocked_guide_trace = guide_trace.copy()
for name in blocked_names:
del blocked_guide_trace.nodes[name]
loss, surrogate_loss = super()._differentiable_loss_particle(
model_trace, blocked_guide_trace
)
# Add log q terms.
for name in blocked_names:
log_q = guide_trace.nodes[name]["log_prob_sum"]
loss = loss - torch_item(log_q)
surrogate_loss = surrogate_loss - log_q
return loss, surrogate_loss
@torch.no_grad()
def loss(self, model, guide, *args, **kwargs):
return torch_item(self.differentiable_loss(model, guide, *args, **kwargs))
class CEVAE(nn.Module):
"""
Main class implementing a Causal Effect VAE [1]. This assumes a graphical model
.. graphviz::
    :graphviz_dot: neato

    digraph {
        Z [pos="1,2!",style=filled];
        X [pos="2,1!"];
        y [pos="1,0!"];
        t [pos="0,1!"];
        Z -> X;
        Z -> t;
        Z -> y;
        t -> y;
    }
where `t` is a binary treatment variable, `y` is an outcome, `Z` is
an unobserved confounder, and `X` is a noisy function of the hidden
confounder `Z`.
Example::
cevae = CEVAE(feature_dim=5)
cevae.fit(x_train, t_train, y_train)
ite = cevae.ite(x_test) # individual treatment effect
ate = ite.mean() # average treatment effect
:ivar Model ~CEVAE.model: Generative model.
:ivar Guide ~CEVAE.guide: Inference model.
:param int feature_dim: Dimension of the feature space `x`.
:param str outcome_dist: One of: "bernoulli" (default), "exponential", "laplace",
"normal", "studentt".
:param int latent_dim: Dimension of the latent variable `z`.
Defaults to 20.
:param int hidden_dim: Dimension of hidden layers of fully connected
networks. Defaults to 200.
:param int num_layers: Number of hidden layers in fully connected networks.
Defaults to 3.
:param int num_samples: Default number of samples for the :meth:`ite`
method. Defaults to 100.
"""
def __init__(
self,
feature_dim,
outcome_dist="bernoulli",
latent_dim=20,
hidden_dim=200,
num_layers=3,
num_samples=100,
):
config = dict(
feature_dim=feature_dim,
latent_dim=latent_dim,
hidden_dim=hidden_dim,
num_layers=num_layers,
num_samples=num_samples,
)
for name, size in config.items():
if not (isinstance(size, int) and size > 0):
raise ValueError("Expected {} > 0 but got {}".format(name, size))
config["outcome_dist"] = outcome_dist
self.feature_dim = feature_dim
self.num_samples = num_samples
super().__init__()
self.model = Model(config)
self.guide = Guide(config)
def fit(
self,
x,
t,
y,
num_epochs=100,
batch_size=100,
learning_rate=1e-3,
learning_rate_decay=0.1,
weight_decay=1e-4,
log_every=100,
):
"""
Train using :class:`~pyro.infer.svi.SVI` with the
:class:`TraceCausalEffect_ELBO` loss.
:param ~torch.Tensor x:
:param ~torch.Tensor t:
:param ~torch.Tensor y:
:param int num_epochs: Number of training epochs. Defaults to 100.
:param int batch_size: Batch size. Defaults to 100.
:param float learning_rate: Learning rate. Defaults to 1e-3.
:param float learning_rate_decay: Learning rate decay over all epochs;
the per-step decay rate will depend on batch size and number of epochs
such that the initial learning rate will be ``learning_rate`` and the final
learning rate will be ``learning_rate * learning_rate_decay``.
Defaults to 0.1.
:param float weight_decay: Weight decay. Defaults to 1e-4.
:param int log_every: Log the loss every ``log_every`` steps. If zero,
do not log the loss. Defaults to 100.
:return: list of per-step training losses
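For example, with ``num_epochs=100`` and 10 minibatches per epoch there are
1000 SVI steps in total, so the optimizer uses a per-step decay factor of
``learning_rate_decay ** (1 / 1000)`` and the learning rate ends at
``learning_rate * learning_rate_decay``.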
"""
assert x.dim() == 2 and x.size(-1) == self.feature_dim
assert t.shape == x.shape[:1]
assert y.shape == x.shape[:1]
self.whiten = PreWhitener(x)
dataset = TensorDataset(x, t, y)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
generator=torch.Generator(device=x.device),
)
logger.info("Training with {} minibatches per epoch".format(len(dataloader)))
num_steps = num_epochs * len(dataloader)
optim = ClippedAdam(
{
"lr": learning_rate,
"weight_decay": weight_decay,
"lrd": learning_rate_decay ** (1 / num_steps),
}
)
svi = SVI(self.model, self.guide, optim, TraceCausalEffect_ELBO())
losses = []
for epoch in range(num_epochs):
for x, t, y in dataloader:
x = self.whiten(x)
loss = svi.step(x, t, y, size=len(dataset)) / len(dataset)
if log_every and len(losses) % log_every == 0:
logger.debug(
"step {: >5d} loss = {:0.6g}".format(len(losses), loss)
)
assert not torch_isnan(loss)
losses.append(loss)
return losses
@torch.no_grad()
def ite(self, x, num_samples=None, batch_size=None):
r"""
Computes Individual Treatment Effect for a batch of data ``x``.
.. math::

    ITE(x) = \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=1) \bigr]
           - \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=0) \bigr]
This has complexity ``O(len(x) * num_samples ** 2)``.
:param ~torch.Tensor x: A batch of data.
:param int num_samples: The number of Monte Carlo samples.
Defaults to ``self.num_samples``, which defaults to 100.
:param int batch_size: Batch size. Defaults to ``len(x)``.
:return: A ``len(x)``-sized tensor of estimated effects.
:rtype: ~torch.Tensor
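For example (sample and batch sizes are illustrative)::

    ite = cevae.ite(x_test, num_samples=200, batch_size=1024)
    ate = ite.mean()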
"""
if num_samples is None:
num_samples = self.num_samples
if not torch._C._get_tracing_state():
assert x.dim() == 2 and x.size(-1) == self.feature_dim
dataloader = [x] if batch_size is None else DataLoader(x, batch_size=batch_size)
logger.info("Evaluating {} minibatches".format(len(dataloader)))
result = []
for x in dataloader:
x = self.whiten(x)
with pyro.plate("num_particles", num_samples, dim=-2):
with poutine.trace() as tr, poutine.block(hide=["y", "t"]):
self.guide(x)
with poutine.do(data=dict(t=torch.zeros(()))):
y0 = poutine.replay(self.model.y_mean, tr.trace)(x)
with poutine.do(data=dict(t=torch.ones(()))):
y1 = poutine.replay(self.model.y_mean, tr.trace)(x)
ite = (y1 - y0).mean(0)
if not torch._C._get_tracing_state():
logger.debug("batch ate = {:0.6g}".format(ite.mean()))
result.append(ite)
return torch.cat(result)
def to_script_module(self):
"""
Compile this module using :func:`torch.jit.trace_module`,
assuming self has already been fit to data.
:return: A traced version of self with an :meth:`ite` method.
:rtype: torch.jit.ScriptModule
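For example (the file name is illustrative)::

    traced = cevae.to_script_module()
    torch.jit.save(traced, "cevae.pt")
    loaded = torch.jit.load("cevae.pt")
    ite = loaded.ite(x_test)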
"""
self.train(False)
fake_x = torch.randn(2, self.feature_dim)
with pyro.validation_enabled(False):
# Disable check_trace due to nondeterministic nodes.
result = torch.jit.trace_module(self, {"ite": (fake_x,)}, check_trace=False)
return result