Causal Effect VAE
This module implements the Causal Effect Variational Autoencoder [1], which demonstrates a number of innovations including:
a generative model for causal effect inference with hidden confounders;
a model and guide with twin neural nets to allow imbalanced treatment; and
a custom training loss that includes both ELBO terms and extra terms needed to train the guide to be able to answer counterfactual queries.
The main interface is the CEVAE class, but users may customize by using the components Model, Guide, TraceCausalEffect_ELBO, and the utilities.
References
- [1] C. Louizos, U. Shalit, J. Mooij, D. Sontag, R. Zemel, M. Welling (2017). Causal Effect Inference with Deep Latent-Variable Models.
CEVAE Class
- class CEVAE(feature_dim, outcome_dist='bernoulli', latent_dim=20, hidden_dim=200, num_layers=3, num_samples=100)
Bases: torch.nn.modules.module.Module
Main class implementing a Causal Effect VAE [1]. This assumes a graphical model
digraph { Z [pos="1,2!",style=filled]; X [pos="2,1!"]; y [pos="1,0!"]; t [pos="0,1!"]; Z -> X; Z -> t; Z -> y; t -> y; }
where t is a binary treatment variable, y is an outcome, Z is an unobserved confounder, and X is a noisy function of the hidden confounder Z.
Example:
    cevae = CEVAE(feature_dim=5)
    cevae.fit(x_train, t_train, y_train)
    ite = cevae.ite(x_test)  # individual treatment effect
    ate = ite.mean()         # average treatment effect
- Variables
- Parameters
feature_dim (int) – Dimension of the feature space x.
outcome_dist (str) – One of: “bernoulli” (default), “exponential”, “laplace”, “normal”, “studentt”.
latent_dim (int) – Dimension of the latent variable z. Defaults to 20.
hidden_dim (int) – Dimension of hidden layers of fully connected networks. Defaults to 200.
num_layers (int) – Number of hidden layers in fully connected networks. Defaults to 3.
num_samples (int) – Default number of samples for the ite() method. Defaults to 100.
- fit(x, t, y, num_epochs=100, batch_size=100, learning_rate=0.001, learning_rate_decay=0.1, weight_decay=0.0001, log_every=100)
Train using SVI with the TraceCausalEffect_ELBO loss.
- Parameters
x (Tensor) –
t (Tensor) –
y (Tensor) –
num_epochs (int) – Number of training epochs. Defaults to 100.
batch_size (int) – Batch size. Defaults to 100.
learning_rate (float) – Learning rate. Defaults to 1e-3.
learning_rate_decay (float) – Learning rate decay over all epochs; the per-step decay rate depends on the batch size and number of epochs such that the initial learning rate is learning_rate and the final learning rate is learning_rate * learning_rate_decay. Defaults to 0.1.
weight_decay (float) – Weight decay. Defaults to 1e-4.
log_every (int) – Log the loss every log_every steps. If zero, do not log the loss. Defaults to 100.
- Returns
list of epoch losses
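The relationship between learning_rate_decay and the per-step decay can be sketched as follows. This is a minimal illustration, assuming an exponential per-step schedule with one optimizer step per mini-batch; the helper name per_step_decay and the dataset size of 1000 are hypothetical, and the library's actual schedule may differ in detail.

```python
import math


def per_step_decay(num_data, num_epochs=100, batch_size=100, learning_rate_decay=0.1):
    """Return (gamma, num_steps) with gamma ** num_steps == learning_rate_decay.

    Hypothetical helper: assumes one optimizer step per mini-batch.
    """
    steps_per_epoch = math.ceil(num_data / batch_size)
    num_steps = num_epochs * steps_per_epoch
    gamma = learning_rate_decay ** (1.0 / num_steps)
    return gamma, num_steps


learning_rate = 1e-3
gamma, num_steps = per_step_decay(num_data=1000)

# After all steps the learning rate has shrunk by learning_rate_decay overall.
final_lr = learning_rate * gamma ** num_steps
```

Under these assumptions, a smaller batch_size or a larger num_epochs gives more steps and therefore a per-step factor closer to 1, while the overall decay stays fixed.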
- ite(x, num_samples=None, batch_size=None)
Computes the Individual Treatment Effect for a batch of data x.
\[ITE(x) = \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=1) \bigr] - \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=0) \bigr]\]
This has complexity O(len(x) * num_samples ** 2).
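The ITE formula can be illustrated with a small Monte Carlo sketch in plain Python. Everything here is a toy stand-in, not the library's implementation: sample_z_given_x plays the role of a learned posterior over the confounder, outcome_mean plays the role of the expected outcome under an intervention, and both functions and their coefficients are hypothetical. The sketch does not reproduce the batching or the cost of ite().

```python
import random


def sample_z_given_x(x, rng):
    # Hypothetical stand-in for sampling the latent confounder given features x.
    return rng.gauss(0.5 * x, 1.0)


def outcome_mean(t, z):
    # Hypothetical stand-in for E[y | z, do(t)]; the effect of t depends on z.
    return z + t * (1.0 + 0.5 * z)


def ite_estimate(x, num_samples=1000, seed=0):
    """Monte Carlo estimate of E[y | x, do(t=1)] - E[y | x, do(t=0)]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = sample_z_given_x(x, rng)  # the same z is shared by both arms
        total += outcome_mean(1, z) - outcome_mean(0, z)
    return total / num_samples


ite = [ite_estimate(x) for x in [-2.0, 0.0, 2.0]]
ate = sum(ite) / len(ite)  # average treatment effect over the batch
```

In this toy setup the exact answer is ITE(x) = 1 + 0.25 * x, so the estimates can be checked against a known value.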
- to_script_module()
Compile this module using torch.jit.trace_module(), assuming self has already been fit to data.
- Returns
A traced version of self with an ite() method.
- Return type
CEVAE Components
- class Model(config)
Bases: pyro.nn.module.PyroModule
Generative model for a causal model with latent confounder z and binary treatment t:
    z ~ p(z)      # latent confounder
    x ~ p(x|z)    # partial noisy observation of z
    t ~ p(t|z)    # treatment, whose application is biased by z
    y ~ p(y|t,z)  # outcome
Each of these distributions is defined by a neural network. The y distribution is defined by a disjoint pair of neural networks defining p(y|t=0,z) and p(y|t=1,z); this allows highly imbalanced treatment.
- Parameters
config (dict) – A dict specifying feature_dim, latent_dim, hidden_dim, num_layers, and outcome_dist.
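The sampling order of the generative model, and why the hidden confounder matters, can be sketched with a toy simulation. This is an illustration under simplifying assumptions, not the library's model: z is binary rather than a latent vector, the conditional distributions are fixed formulas rather than neural networks, and the names y_mean_t0, y_mean_t1, and sample_once are hypothetical. The two separate outcome functions mirror the disjoint pair used for p(y|t=0,z) and p(y|t=1,z).

```python
import random


def y_mean_t0(z):
    # Outcome head used only for untreated units (t == 0).
    return 2.0 * z


def y_mean_t1(z):
    # Separate outcome head used only for treated units (t == 1).
    return 2.0 * z + 1.0


def sample_once(rng):
    z = 1 if rng.random() < 0.5 else 0                     # z ~ p(z)
    x = rng.gauss(z, 1.0)                                  # x ~ p(x|z)
    t = 1 if rng.random() < (0.75 if z else 0.25) else 0   # t ~ p(t|z)
    mean = y_mean_t1(z) if t == 1 else y_mean_t0(z)
    y = rng.gauss(mean, 1.0)                               # y ~ p(y|t,z)
    return z, x, t, y


rng = random.Random(0)
data = [sample_once(rng) for _ in range(20000)]

treated = [y for _, _, t, y in data if t == 1]
control = [y for _, _, t, y in data if t == 0]

# Naive difference in means, ignoring the confounder z.
naive_effect = sum(treated) / len(treated) - sum(control) / len(control)
true_effect = 1.0  # y_mean_t1(z) - y_mean_t0(z) for every z
```

Because z raises both the chance of treatment and the outcome, the naive difference in means is close to 2 in this toy setup even though the true effect is 1, which is the bias that modeling the confounder is meant to remove.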
- class Guide(config)
Bases: pyro.nn.module.PyroModule
Inference model for causal effect estimation with latent confounder z and binary treatment t:
    t ~ q(t|x)      # treatment
    y ~ q(y|t,x)    # outcome
    z ~ q(z|y,t,x)  # latent confounder, an embedding
Each of these distributions is defined by a neural network. The y and z distributions are defined by disjoint pairs of neural networks defining q(-|t=0,...) and q(-|t=1,...); this allows highly imbalanced treatment.
- Parameters
config (dict) – A dict specifying feature_dim, latent_dim, hidden_dim, num_layers, and outcome_dist.
- class TraceCausalEffect_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)
Bases: pyro.infer.trace_elbo.Trace_ELBO
Loss function for training a CEVAE. From [1], the CEVAE objective (to maximize) is:
    -loss = ELBO + log q(t|x) + log q(y|t,x)
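For reference, the objective can be written with the ELBO term expanded for a single data point (x, t, y), using the model and guide factorizations described above. This follows the formulation in [1]; the exact weighting of terms in the implementation is not restated here and should be treated as an assumption.

```latex
\[
\mathrm{ELBO} = \mathbb E_{q(z \mid x, t, y)}\bigl[ \log p(x, t \mid z) + \log p(y \mid t, z) + \log p(z) - \log q(z \mid x, t, y) \bigr]
\]
\[
-\mathrm{loss} = \mathrm{ELBO} + \log q(t \mid x) + \log q(y \mid t, x)
\]
```

The two extra terms train the guide's treatment and outcome networks, which the ELBO alone does not constrain because t and y are observed during training.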
Utilities
- class FullyConnected(sizes, final_activation=None)
Bases: torch.nn.modules.container.Sequential
Fully connected multi-layer network with ELU activations.
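The structure of such a network can be sketched in plain Python without any tensor library. This is a minimal illustration, assuming that ELU is applied between layers and that the last layer is left linear unless a final activation is given; the helper names make_layers and forward and the random initialization are hypothetical and are not the library's implementation.

```python
import math
import random


def elu(v):
    # Exponential linear unit with alpha = 1.
    return v if v > 0 else math.exp(v) - 1.0


def make_layers(sizes, seed=0):
    """Build one (weights, biases) pair per consecutive pair of sizes."""
    rng = random.Random(seed)
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        weights = [[rng.uniform(-1.0, 1.0) / math.sqrt(n_in) for _ in range(n_in)]
                   for _ in range(n_out)]
        biases = [0.0] * n_out
        layers.append((weights, biases))
    return layers


def forward(layers, x, final_activation=None):
    h = list(x)
    last = len(layers) - 1
    for i, (weights, biases) in enumerate(layers):
        # Affine map: one output per row of the weight matrix.
        h = [sum(w * v for w, v in zip(row, h)) + b for row, b in zip(weights, biases)]
        if i < last:
            h = [elu(v) for v in h]  # ELU between layers only
    if final_activation is not None:
        h = [final_activation(v) for v in h]
    return h


layers = make_layers([3, 4, 5])
out = forward(layers, [0.1, -0.2, 0.3])
```

Here sizes lists the input width, the hidden widths, and the output width in order, so [3, 4, 5] maps a 3-dimensional input to a 5-dimensional output through one hidden layer of width 4.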
- class DistributionNet(*args, **kwargs)
Bases: torch.nn.modules.module.Module
Base class for distribution nets.
- class BernoulliNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a single logits value.
This is used to represent a conditional probability distribution of a single Bernoulli random variable conditioned on a sizes[0]-sized real value, for example:
    net = BernoulliNet([3, 4])
    z = torch.randn(3)
    logits, = net(z)
    t = net.make_dist(logits).sample()
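How a logits value parameterizes a Bernoulli variable can be sketched in plain Python. This is a generic illustration of the logits parameterization, not the library's code; the helper names sigmoid and sample_bernoulli are hypothetical.

```python
import math
import random


def sigmoid(logits):
    # Numerically stable logistic function mapping logits to a probability.
    if logits >= 0:
        return 1.0 / (1.0 + math.exp(-logits))
    e = math.exp(logits)
    return e / (1.0 + e)


def sample_bernoulli(logits, rng):
    # Draw 1 with probability sigmoid(logits), otherwise 0.
    return 1 if rng.random() < sigmoid(logits) else 0


rng = random.Random(0)
p = sigmoid(0.0)  # logits of zero correspond to probability one half
draws = [sample_bernoulli(2.0, rng) for _ in range(10000)]
freq = sum(draws) / len(draws)  # empirical frequency of ones
```

Working with logits rather than probabilities keeps the network output unconstrained, since any real number maps to a valid probability.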
- class ExponentialNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a constrained rate.
This is used to represent a conditional probability distribution of a single Exponential random variable conditioned on a sizes[0]-sized real value, for example:
    net = ExponentialNet([3, 4])
    x = torch.randn(3)
    rate, = net(x)
    y = net.make_dist(rate).sample()
- class LaplaceNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a constrained loc,scale pair.
This is used to represent a conditional probability distribution of a single Laplace random variable conditioned on a sizes[0]-sized real value, for example:
    net = LaplaceNet([3, 4])
    x = torch.randn(3)
    loc, scale = net(x)
    y = net.make_dist(loc, scale).sample()
- class NormalNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a constrained loc,scale pair.
This is used to represent a conditional probability distribution of a single Normal random variable conditioned on a sizes[0]-sized real value, for example:
    net = NormalNet([3, 4])
    x = torch.randn(3)
    loc, scale = net(x)
    y = net.make_dist(loc, scale).sample()
- class StudentTNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a constrained df,loc,scale triple, with shared df > 1.
This is used to represent a conditional probability distribution of a single Student's t random variable conditioned on a sizes[0]-sized real value, for example:
    net = StudentTNet([3, 4])
    x = torch.randn(3)
    df, loc, scale = net(x)
    y = net.make_dist(df, loc, scale).sample()
- class DiagNormalNet(sizes)
Bases: torch.nn.modules.module.Module
FullyConnected network outputting a constrained loc,scale pair.
This is used to represent a conditional probability distribution of a sizes[-1]-sized diagonal Normal random variable conditioned on a sizes[0]-sized real value, for example:
    net = DiagNormalNet([3, 4, 5])
    z = torch.randn(3)
    loc, scale = net(z)
    x = dist.Normal(loc, scale).sample()
This is intended for the latent z distribution and the prewhitened x features, and conservatively clips loc and scale values.
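One common way to produce a constrained and clipped loc,scale pair from unconstrained network outputs is a softplus followed by clamping. The sketch below illustrates that idea in plain Python; it is an assumption about the general technique, not the library's code, and the function names and the bounds (max_abs=1e2, min_scale=1e-3) are hypothetical.

```python
import math


def softplus(v):
    # Numerically stable log(1 + exp(v)); maps any real number to a non-negative one.
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def clamp(v, lo, hi):
    return min(max(v, lo), hi)


def constrain_loc_scale(raw_loc, raw_scale, max_abs=1e2, min_scale=1e-3):
    """Map unconstrained outputs to a bounded loc and a positive, bounded scale."""
    loc = clamp(raw_loc, -max_abs, max_abs)
    scale = clamp(softplus(raw_scale), min_scale, max_abs)
    return loc, scale


loc, scale = constrain_loc_scale(0.0, 0.0)
extreme_loc, extreme_scale = constrain_loc_scale(1e9, -1e9)
```

Clipping of this kind trades a small amount of expressiveness for protection against overflow and degenerate, near-zero scales during early training.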