Causal Effect VAE
This module implements the Causal Effect Variational Autoencoder [1], which demonstrates a number of innovations including:
- a generative model for causal effect inference with hidden confounders;
- a model and guide with twin neural nets to allow imbalanced treatment; and
- a custom training loss that includes both ELBO terms and extra terms needed to train the guide to be able to answer counterfactual queries.
The main interface is the CEVAE class, but users may customize it by using the components Model, Guide, TraceCausalEffect_ELBO and the utilities.
References
- [1] C. Louizos, U. Shalit, J. Mooij, D. Sontag, R. Zemel, M. Welling (2017). Causal Effect Inference with Deep Latent-Variable Models.
CEVAE Class
class CEVAE(feature_dim, outcome_dist='bernoulli', latent_dim=20, hidden_dim=200, num_layers=3, num_samples=100)
Bases: torch.nn.modules.module.Module
Main class implementing a Causal Effect VAE [1]. This assumes a graphical model
[Figure: graphical model with edges Z -> X, Z -> t, Z -> y and t -> y.]
where t is a binary treatment variable, y is an outcome, Z is an unobserved confounder, and X is a noisy function of the hidden confounder Z.
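To make this graph concrete, here is a minimal stdlib-only sketch that simulates data with the same edge structure. The function names, the linear outcome, the three proxy features and the treatment probabilities are hypothetical choices for illustration, not part of the CEVAE API.

```python
import random


def simulate_confounded_data(n, effect=2.0, seed=0):
    """Draw n samples with edges Z -> X, Z -> t, Z -> y and t -> y.

    Hypothetical functional forms; only the edge structure follows the graph.
    Returns features, treatments, observed outcomes and true per-unit effects.
    """
    rng = random.Random(seed)
    xs, ts, ys, ites = [], [], [], []
    for _ in range(n):
        z = rng.gauss(0.0, 1.0)                          # hidden confounder Z
        x = [z + rng.gauss(0.0, 0.5) for _ in range(3)]  # noisy proxies X of Z
        p_treat = 0.8 if z > 0 else 0.2                  # treatment biased by Z
        t = 1 if rng.random() < p_treat else 0
        y0 = 3.0 * z + rng.gauss(0.0, 0.1)               # outcome under do(t=0)
        y1 = y0 + effect                                 # outcome under do(t=1)
        xs.append(x)
        ts.append(t)
        ys.append(y1 if t == 1 else y0)                  # only one outcome is observed
        ites.append(y1 - y0)
    return xs, ts, ys, ites


def naive_difference(ts, ys):
    """Treated-minus-control difference in mean outcome, ignoring Z."""
    treated = [y for t, y in zip(ts, ys) if t == 1]
    control = [y for t, y in zip(ts, ys) if t == 0]
    return sum(treated) / len(treated) - sum(control) / len(control)


xs, ts, ys, ites = simulate_confounded_data(10_000)
print("true ATE:", sum(ites) / len(ites))
print("naive estimate:", naive_difference(ts, ys))
```

In this toy setup the naive estimate is biased upward: units with large Z are both more likely to be treated and have larger outcomes. Estimating the effect when Z is hidden and only its noisy proxies X are observed is the problem CEVAE addresses.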
Example:

    from pyro.contrib.cevae import CEVAE

    cevae = CEVAE(feature_dim=5)
    cevae.fit(x_train, t_train, y_train)
    ite = cevae.ite(x_test)  # individual treatment effect
    ate = ite.mean()         # average treatment effect
Parameters:
- feature_dim (int) – Dimension of the feature space x.
- outcome_dist (str) – One of: "bernoulli" (default), "exponential", "laplace", "normal", "studentt".
- latent_dim (int) – Dimension of the latent variable z. Defaults to 20.
- hidden_dim (int) – Dimension of the hidden layers of the fully connected networks. Defaults to 200.
- num_layers (int) – Number of hidden layers in the fully connected networks. Defaults to 3.
- num_samples (int) – Default number of samples for the ite() method. Defaults to 100.
fit(x, t, y, num_epochs=100, batch_size=100, learning_rate=0.001, learning_rate_decay=0.1, weight_decay=0.0001)
Train using SVI with the TraceCausalEffect_ELBO loss.
Parameters:
- x (Tensor) – A batch of features, with final dimension feature_dim.
- t (Tensor) – A batch of binary treatments.
- y (Tensor) – A batch of outcomes.
- num_epochs (int) – Number of training epochs. Defaults to 100.
- batch_size (int) – Batch size. Defaults to 100.
- learning_rate (float) – Learning rate. Defaults to 1e-3.
- learning_rate_decay (float) – Learning rate decay over all epochs; the per-step decay rate depends on the batch size and the number of epochs, such that the initial learning rate is learning_rate and the final learning rate is learning_rate * learning_rate_decay. Defaults to 0.1.
- weight_decay (float) – Weight decay. Defaults to 1e-4.
Returns: list of epoch losses
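The learning_rate_decay description implies a geometric schedule: with S total optimizer steps, multiplying the rate by learning_rate_decay ** (1 / S) at every step takes it from learning_rate to learning_rate * learning_rate_decay. The sketch below checks that arithmetic; counting one step per minibatch (including a final partial batch) is an assumption, not something this page states.

```python
import math


def per_step_decay(learning_rate_decay, num_epochs, num_examples, batch_size):
    """Return (gamma, num_steps) such that gamma ** num_steps == learning_rate_decay.

    Assumes one optimizer step per minibatch and a final partial batch
    (hypothetical step accounting, for illustration).
    """
    steps_per_epoch = math.ceil(num_examples / batch_size)
    num_steps = num_epochs * steps_per_epoch
    gamma = learning_rate_decay ** (1.0 / num_steps)
    return gamma, num_steps


def final_learning_rate(learning_rate, gamma, num_steps):
    # Geometric schedule: the rate is multiplied by gamma after every step.
    return learning_rate * gamma ** num_steps


gamma, num_steps = per_step_decay(0.1, num_epochs=100, num_examples=1000, batch_size=100)
print(num_steps, gamma, final_learning_rate(1e-3, gamma, num_steps))
```

More epochs or smaller batches mean more steps, so gamma moves closer to 1 while the end-to-end decay stays fixed.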
ite(x, num_samples=None, batch_size=None)
Computes the Individual Treatment Effect for a batch of data x:
\[ITE(x) = \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=1) \bigr] - \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=0) \bigr]\]
This has complexity O(len(x) * num_samples ** 2).
Parameters:
- x (Tensor) – A batch of data.
- num_samples (int) – The number of Monte Carlo samples. Defaults to the num_samples given to the constructor.
- batch_size (int) – Batch size.
Returns: A len(x)-sized tensor of estimated effects.
Return type: torch.Tensor
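The expectations in the ITE formula are estimated by sampling. This stdlib sketch shows a generic Monte Carlo estimator that averages the difference of two outcome means over draws of the latent z; sample_z, mean_y0 and mean_y1 are hypothetical stand-ins for the trained guide and model networks, not the internals of CEVAE.ite().

```python
import random
from statistics import fmean


def monte_carlo_ite(sample_z, mean_y0, mean_y1, num_samples=100, seed=0):
    """Estimate E[y | do(t=1)] - E[y | do(t=0)] by averaging over z draws.

    sample_z(rng): hypothetical sampler for z given a fixed x (stands in for the guide)
    mean_y0, mean_y1: hypothetical outcome means under t=0 and t=1 (stand in for the model)
    """
    rng = random.Random(seed)
    zs = [sample_z(rng) for _ in range(num_samples)]
    # Reusing the same z draws for both interventions keeps the two
    # expectations paired, which reduces the variance of their difference.
    return fmean(mean_y1(z) - mean_y0(z) for z in zs)
```

For instance, with mean_y1(z) = 3z + 2 and mean_y0(z) = 3z every draw contributes 2, so the estimate is 2 up to rounding, whatever num_samples is.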
to_script_module()
Compile this module using torch.jit.trace_module(), assuming self has already been fit to data.
Returns: A traced version of self with an ite() method.
Return type: torch.jit.ScriptModule
CEVAE Components
class Model(config)
Bases: pyro.nn.module.PyroModule
Generative model for a causal model with latent confounder z and binary treatment t:

    z ~ p(z)      # latent confounder
    x ~ p(x|z)    # partial noisy observation of z
    t ~ p(t|z)    # treatment, whose application is biased by z
    y ~ p(y|t,z)  # outcome

Each of these distributions is defined by a neural network. The y distribution is defined by a disjoint pair of neural networks defining p(y|t=0,z) and p(y|t=1,z); this allows highly imbalanced treatment.
Parameters: config (dict) – A dict specifying feature_dim, latent_dim, hidden_dim, num_layers, and outcome_dist.
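The factorization above can be read as an ancestral sampling program. This stdlib sketch follows the same order (z, then x and t, then y); the linear "networks", the sigmoid treatment model and the noise scales are hypothetical placeholders, and the two outcome functions stand in for the disjoint pair of y networks.

```python
import math
import random


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def y_mean_t0(z):
    return 1.0 * z          # hypothetical head for p(y|t=0,z)


def y_mean_t1(z):
    return 1.0 * z + 2.0    # hypothetical head for p(y|t=1,z)


def sample_model(rng, feature_dim=3):
    """One ancestral sample (z, x, t, y) following the Model's factorization."""
    z = rng.gauss(0.0, 1.0)                                 # z ~ p(z)
    x = [rng.gauss(z, 1.0) for _ in range(feature_dim)]     # x ~ p(x|z)
    t = 1 if rng.random() < sigmoid(2.0 * z) else 0         # t ~ p(t|z)
    # Only the head matching the sampled treatment is used for y.
    y_mean = y_mean_t1 if t == 1 else y_mean_t0
    y = rng.gauss(y_mean(z), 0.1)                           # y ~ p(y|t,z)
    return z, x, t, y


print(sample_model(random.Random(0)))
```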
class Guide(config)
Bases: pyro.nn.module.PyroModule
Inference model for causal effect estimation with latent confounder z and binary treatment t:

    t ~ q(t|x)      # treatment
    y ~ q(y|t,x)    # outcome
    z ~ q(z|y,t,x)  # latent confounder, an embedding

Each of these distributions is defined by a neural network. The y and z distributions are defined by disjoint pairs of neural networks defining q(-|t=0,...) and q(-|t=1,...); this allows highly imbalanced treatment.
Parameters: config (dict) – A dict specifying feature_dim, latent_dim, hidden_dim, num_layers, and outcome_dist.
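The twin-network design amounts to evaluating both heads and keeping, per example, the output of the head that matches the treatment. The hypothetical helper below shows that selection on plain lists; with tensors, torch.where(t.bool(), out1, out0) expresses the same elementwise choice.

```python
def twin_select(t, out0, out1):
    """Per example, keep the output of the head matching its treatment.

    t: list of 0/1 treatments; out0, out1: outputs of the t=0 and t=1 heads.
    Hypothetical helper, not part of pyro.contrib.cevae.
    """
    if not (len(t) == len(out0) == len(out1)):
        raise ValueError("t, out0 and out1 must have the same length")
    return [o1 if ti else o0 for ti, o0, o1 in zip(t, out0, out1)]


print(twin_select([0, 1, 1, 0], [10, 11, 12, 13], [20, 21, 22, 23]))
```

Because each head only contributes for examples with its own treatment value, a rarely applied treatment still has dedicated parameters rather than sharing one network dominated by the common arm.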
class TraceCausalEffect_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)
Bases: pyro.infer.trace_elbo.Trace_ELBO
Loss function for training a CEVAE. From [1], the CEVAE objective (to maximize) is:

    -loss = ELBO + log q(t|x) + log q(y|t,x)
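As a worked instance of this objective, the sketch assembles the scalar loss from an ELBO value and the two auxiliary log-probabilities. The numbers are made up, and treating y as binary assumes outcome_dist="bernoulli"; in the library these terms are produced from the model and guide during SVI, not passed in by hand.

```python
import math


def cevae_loss(elbo, log_q_t, log_q_y):
    """Scalar loss to minimize, from -loss = ELBO + log q(t|x) + log q(y|t,x)."""
    return -(elbo + log_q_t + log_q_y)


# Hypothetical per-example values, for illustration only.
elbo = -12.5              # evidence lower bound from model and guide
log_q_t = math.log(0.8)   # guide gives probability 0.8 to the observed t
log_q_y = math.log(0.6)   # guide gives probability 0.6 to the observed (binary) y
loss = cevae_loss(elbo, log_q_t, log_q_y)
print(loss)
```

For discrete t and y the log-probabilities are at most 0, so the auxiliary terms only add to the loss and shrink as the guide predicts the observed treatment and outcome more confidently; this is what trains the guide's t and y networks to answer counterfactual queries.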
Utilities
class FullyConnected(sizes, final_activation=None)
Bases: torch.nn.modules.container.Sequential
Fully connected multi-layer network with ELU activations.
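For intuition, here is a dependency-free forward pass in the spirit of FullyConnected: linear layers joined by ELU, with an optional final activation. It assumes ELU is applied between layers but not after the last linear layer, and that sizes such as [3, 4, 5] correspond to two linear maps (3 to 4, then 4 to 5); both are assumptions about the implementation, not statements from this page.

```python
import math


def elu(a, alpha=1.0):
    # ELU: identity for a > 0, alpha * (exp(a) - 1) otherwise.
    return a if a > 0 else alpha * (math.exp(a) - 1.0)


def linear(weights, bias, v):
    # weights: out_size rows of length in_size; returns weights @ v + bias.
    return [sum(w * x for w, x in zip(row, v)) + b for row, b in zip(weights, bias)]


def fully_connected(layers, v, final_activation=None):
    """Apply (weights, bias) layers in order, with ELU between consecutive layers."""
    for i, (weights, bias) in enumerate(layers):
        v = linear(weights, bias, v)
        if i < len(layers) - 1:          # no ELU after the last linear layer
            v = [elu(a) for a in v]
    if final_activation is not None:
        v = [final_activation(a) for a in v]
    return v


# Hypothetical 2 -> 1 -> 1 network.
layers = [([[1.0, -1.0]], [0.0]), ([[2.0]], [1.0])]
print(fully_connected(layers, [0.0, 1.0]))
```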
class DistributionNet
Bases: torch.nn.modules.module.Module
Base class for distribution nets.
class BernoulliNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a single logits value.
This is used to represent a conditional probability distribution of a single Bernoulli random variable conditioned on a sizes[0]-sized real value, for example:

    import torch
    from pyro.contrib.cevae import BernoulliNet

    net = BernoulliNet([3, 4])
    z = torch.randn(3)
    logits, = net(z)
    t = net.make_dist(logits).sample()
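The logits output is unconstrained; the corresponding probability is its sigmoid. This stdlib sketch shows that mapping and what a Bernoulli distribution built from logits then provides (a log-probability and a sampler); the helper names are hypothetical and are not methods of BernoulliNet.

```python
import math
import random


def sigmoid(a):
    # Numerically stable logistic function: logits -> probability in (0, 1).
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)


def bernoulli_log_prob(value, logits):
    # log p(value) for value in {0, 1}; log1p keeps log(1 - p) accurate for small p.
    p = sigmoid(logits)
    return math.log(p) if value == 1 else math.log1p(-p)


def sample_bernoulli(rng, logits):
    # Draw 1 with probability sigmoid(logits), else 0.
    return 1 if rng.random() < sigmoid(logits) else 0


rng = random.Random(0)
print(sigmoid(0.0), bernoulli_log_prob(1, 0.0), sample_bernoulli(rng, 3.0))
```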
class ExponentialNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a constrained rate.
This is used to represent a conditional probability distribution of a single Exponential random variable conditioned on a sizes[0]-sized real value, for example:

    import torch
    from pyro.contrib.cevae import ExponentialNet

    net = ExponentialNet([3, 4])
    x = torch.randn(3)
    rate, = net(x)
    y = net.make_dist(rate).sample()
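An Exponential rate must be strictly positive, so the raw network output has to be constrained. The sketch shows one common choice, softplus followed by clipping, together with the Exponential log-density the rate defines; the transform and the bounds are hypothetical and not taken from ExponentialNet.

```python
import math


def softplus(a):
    # Numerically stable log(1 + exp(a)); positive for every finite a.
    return max(a, 0.0) + math.log1p(math.exp(-abs(a)))


def constrain_rate(raw, min_rate=1e-3, max_rate=1e6):
    """Map an unconstrained output to a positive, bounded rate (hypothetical bounds)."""
    return min(max(softplus(raw), min_rate), max_rate)


def exponential_log_prob(y, rate):
    # log p(y) = log(rate) - rate * y on the support y >= 0.
    if y < 0:
        return -math.inf
    return math.log(rate) - rate * y


print(constrain_rate(0.0), exponential_log_prob(1.0, constrain_rate(0.0)))
```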
class LaplaceNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a constrained loc, scale pair.
This is used to represent a conditional probability distribution of a single Laplace random variable conditioned on a sizes[0]-sized real value, for example:

    import torch
    from pyro.contrib.cevae import LaplaceNet

    net = LaplaceNet([3, 4])
    x = torch.randn(3)
    loc, scale = net(x)
    y = net.make_dist(loc, scale).sample()
class NormalNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a constrained loc, scale pair.
This is used to represent a conditional probability distribution of a single Normal random variable conditioned on a sizes[0]-sized real value, for example:

    import torch
    from pyro.contrib.cevae import NormalNet

    net = NormalNet([3, 4])
    x = torch.randn(3)
    loc, scale = net(x)
    y = net.make_dist(loc, scale).sample()
class StudentTNet(sizes)
Bases: pyro.contrib.cevae.DistributionNet
FullyConnected network outputting a constrained df, loc, scale triple, with shared df > 1.
This is used to represent a conditional probability distribution of a single Student's t random variable conditioned on a sizes[0]-sized real value, for example:

    import torch
    from pyro.contrib.cevae import StudentTNet

    net = StudentTNet([3, 4])
    x = torch.randn(3)
    df, loc, scale = net(x)
    y = net.make_dist(df, loc, scale).sample()
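The df > 1 constraint matters because a Student's t distribution has a finite mean only when df > 1, and the treatment effect is defined through outcome means. The sketch shows one way to realize a shared df: a single unconstrained scalar (hypothetically named df_unconstrained) mapped through 1 + softplus and repeated across the batch. This is an assumed mechanism, not the StudentTNet source.

```python
import math


def softplus(a):
    # Numerically stable log(1 + exp(a)).
    return max(a, 0.0) + math.log1p(math.exp(-abs(a)))


def shared_df(df_unconstrained, batch_size):
    """Map one learned scalar to df > 1 and share it across the batch.

    df > 1 holds in exact arithmetic; in floating point a very negative input
    can round df to exactly 1.0, so a real implementation might add a margin.
    """
    df = 1.0 + softplus(df_unconstrained)
    return [df] * batch_size


print(shared_df(0.0, 3))
```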
class DiagNormalNet(sizes)
Bases: torch.nn.modules.module.Module
FullyConnected network outputting a constrained loc, scale pair.
This is used to represent a conditional probability distribution of a sizes[-1]-sized diagonal Normal random variable conditioned on a sizes[0]-sized real value, for example:

    import torch
    import pyro.distributions as dist
    from pyro.contrib.cevae import DiagNormalNet

    net = DiagNormalNet([3, 4, 5])
    z = torch.randn(3)
    loc, scale = net(z)
    x = dist.Normal(loc, scale).sample()

This is intended for the latent z distribution and the prewhitened x features, and conservatively clips loc and scale values.
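To illustrate what conservative clipping can look like, this sketch splits a raw output of length 2 * d into a clipped loc and a softplus-constrained, clipped scale. The layout (loc values first, then raw scale values) and the numeric bounds are hypothetical; DiagNormalNet's actual bounds are not given on this page.

```python
import math


def softplus(a):
    # Numerically stable log(1 + exp(a)).
    return max(a, 0.0) + math.log1p(math.exp(-abs(a)))


def clip(v, lo, hi):
    return min(max(v, lo), hi)


def split_loc_scale(raw, loc_bound=100.0, min_scale=1e-3, max_scale=100.0):
    """Split a raw output of length 2*d into clipped loc and positive, clipped scale."""
    if len(raw) % 2:
        raise ValueError("raw output must have even length 2*d")
    d = len(raw) // 2
    loc = [clip(a, -loc_bound, loc_bound) for a in raw[:d]]
    scale = [clip(softplus(a), min_scale, max_scale) for a in raw[d:]]
    return loc, scale


print(split_loc_scale([500.0, -0.5, -200.0, 0.0]))
```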