Causal Effect VAE

This module implements the Causal Effect Variational Autoencoder [1], which demonstrates a number of innovations including:

  • a generative model for causal effect inference with hidden confounders;

  • a model and guide with twin neural nets to allow imbalanced treatment; and

  • a custom training loss that includes both ELBO terms and extra terms needed to train the guide to be able to answer counterfactual queries.

The main interface is the CEVAE class, but users may customize it by using the components Model, Guide, and TraceCausalEffect_ELBO, as well as the utilities.

References

[1] C. Louizos, U. Shalit, J. Mooij, D. Sontag, R. Zemel, M. Welling (2017). Causal Effect Inference with Deep Latent-Variable Models.

CEVAE Class

class CEVAE(feature_dim, outcome_dist='bernoulli', latent_dim=20, hidden_dim=200, num_layers=3, num_samples=100)

Bases: torch.nn.modules.module.Module

Main class implementing a Causal Effect VAE [1]. This assumes the following graphical model:

digraph {
    Z [pos="1,2!",style=filled];
    X [pos="2,1!"];
    y [pos="1,0!"];
    t [pos="0,1!"];
    Z -> X;
    Z -> t;
    Z -> y;
    t -> y;
}

where t is a binary treatment variable, y is an outcome, Z is an unobserved confounder, and X is a noisy function of the hidden confounder Z.
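To see why the hidden confounder matters, the sketch below simulates a toy instance of this graph with hypothetical structural equations (a binary Z, no X, and made-up coefficients; none of this is part of CEVAE). Because Z raises both the chance of treatment and the outcome, the naive difference between treated and untreated means overstates the true effect of t on y.

```python
import random

def simulate(n, seed=0):
    """Sample (z, t, y) from hypothetical structural equations for the graph:
    z ~ Bernoulli(0.5); t ~ Bernoulli(0.8 if z else 0.2); y = 1.0*t + 2.0*z + noise.
    """
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        z = 1 if rng.random() < 0.5 else 0
        t = 1 if rng.random() < (0.8 if z else 0.2) else 0  # treatment biased by z
        y = 1.0 * t + 2.0 * z + rng.gauss(0.0, 0.1)         # true effect of t is 1.0
        rows.append((z, t, y))
    return rows

def naive_difference(rows):
    """E[y | t=1] - E[y | t=0], ignoring the confounder z."""
    treated = [y for _, t, y in rows if t == 1]
    control = [y for _, t, y in rows if t == 0]
    return sum(treated) / len(treated) - sum(control) / len(control)

def true_ate(rows):
    """Mean of y(do(t=1)) - y(do(t=0)) per unit, using the known z (noise cancels)."""
    effects = [(1.0 * 1 + 2.0 * z) - (1.0 * 0 + 2.0 * z) for z, _, _ in rows]
    return sum(effects) / len(effects)

rows = simulate(100_000)
print(naive_difference(rows))  # approximately 2.2, biased upward by z
print(true_ate(rows))          # 1.0
```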

Example:

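from pyro.contrib.cevae import CEVAE

# Assumes x_train, x_test are tensors of shape (N, 5) and t_train, y_train have shape (N,).
cevae = CEVAE(feature_dim=5)
cevae.fit(x_train, t_train, y_train)
ite = cevae.ite(x_test)  # individual treatment effect
ate = ite.mean()         # average treatment effect
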
Variables
  • model (Model) – Generative model.

  • guide (Guide) – Inference model.

Parameters
  • feature_dim (int) – Dimension of the feature space x.

  • outcome_dist (str) – One of: “bernoulli” (default), “exponential”, “laplace”, “normal”, “studentt”.

  • latent_dim (int) – Dimension of the latent variable z. Defaults to 20.

  • hidden_dim (int) – Dimension of hidden layers of fully connected networks. Defaults to 200.

  • num_layers (int) – Number of hidden layers in fully connected networks. Defaults to 3.

  • num_samples (int) – Default number of samples for the ite() method. Defaults to 100.

fit(x, t, y, num_epochs=100, batch_size=100, learning_rate=0.001, learning_rate_decay=0.1, weight_decay=0.0001, log_every=100)

Train using SVI with the TraceCausalEffect_ELBO loss.

Parameters
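  • x (Tensor) – A batch of features x, of shape (N, feature_dim).

  • t (Tensor) – A batch of binary treatments t, of shape (N,).

  • y (Tensor) – A batch of outcomes y, of shape (N,).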

  • num_epochs (int) – Number of training epochs. Defaults to 100.

  • batch_size (int) – Batch size. Defaults to 100.

  • learning_rate (float) – Learning rate. Defaults to 1e-3.

  • learning_rate_decay (float) – Learning rate decay over all epochs; the per-step decay rate will depend on batch size and number of epochs such that the initial learning rate will be learning_rate and the final learning rate will be learning_rate * learning_rate_decay. Defaults to 0.1.

  • weight_decay (float) – Weight decay. Defaults to 1e-4.

  • log_every (int) – Log the loss every log_every steps. If zero, do not log the loss. Defaults to 100.

Returns

list of epoch losses
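The following is a minimal sketch of how a per-step decay factor can be derived from learning_rate_decay so that the learning rate ends at learning_rate * learning_rate_decay. It assumes geometric decay with one optimizer step per minibatch (a final partial batch counts as a step); per_step_decay is a hypothetical helper, not part of the CEVAE API.

```python
import math

def per_step_decay(learning_rate_decay, num_epochs, num_data, batch_size):
    """Per-step multiplicative factor whose product over all steps equals
    learning_rate_decay (assumes one step per minibatch, partial batch included)."""
    steps_per_epoch = math.ceil(num_data / batch_size)
    num_steps = num_epochs * steps_per_epoch
    return learning_rate_decay ** (1.0 / num_steps)

lr0 = 1e-3
gamma = per_step_decay(0.1, num_epochs=100, num_data=1000, batch_size=100)
final_lr = lr0 * gamma ** (100 * 10)  # 100 epochs * 10 steps per epoch
print(final_lr)  # approximately 1e-4
```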

ite(x, num_samples=None, batch_size=None)

Computes the Individual Treatment Effect (ITE) for a batch of data x.

\[ITE(x) = \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=1) \bigr] - \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=0) \bigr]\]

This has complexity O(len(x) * num_samples ** 2).

Parameters
  • x (Tensor) – A batch of data.

  • num_samples (int) – The number of Monte Carlo samples. Defaults to self.num_samples, which defaults to 100.

  • batch_size (int) – Batch size. Defaults to len(x).

Returns

A len(x)-sized tensor of estimated effects.

Return type

Tensor
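A Monte Carlo estimate of the ITE formula above can be sketched by sampling the confounder z given x, then evaluating the expected outcome at both t=1 and t=0 for the same z. In this sketch, sample_z and y_mean are hypothetical toy stand-ins for the trained guide and model, and estimate_ite is an illustrative helper rather than the implementation of ite().

```python
import random

def estimate_ite(x, sample_z, y_mean, num_samples=100, seed=0):
    """Monte Carlo estimate of E[y | x, do(t=1)] - E[y | x, do(t=0)]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = sample_z(x, rng)
        # Intervene on t: evaluate the outcome mean at t=1 and t=0 with the
        # same z sample, instead of sampling t from its (biased) distribution.
        total += y_mean(1, z) - y_mean(0, z)
    return total / num_samples

def sample_z(x, rng):
    """Hypothetical posterior over z: concentrated around x."""
    return rng.gauss(x, 0.1)

def y_mean(t, z):
    """Hypothetical E[y | t, z]: the effect of t grows with z."""
    return 0.5 * z + t * (1.0 + z)

ites = [estimate_ite(x, sample_z, y_mean) for x in [0.0, 1.0, 2.0]]
print(ites)  # roughly [1.0, 2.0, 3.0], since the true effect is 1 + z
```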

to_script_module()

Compiles this module using torch.jit.trace_module(), assuming self has already been fit to data.

Returns

A traced version of self with an ite() method.

Return type

torch.jit.ScriptModule


CEVAE Components

class Model(config)

Bases: pyro.nn.module.PyroModule

Generative model for a causal model with latent confounder z and binary treatment t:

z ~ p(z)      # latent confounder
x ~ p(x|z)    # partial noisy observation of z
t ~ p(t|z)    # treatment, whose application is biased by z
y ~ p(y|t,z)  # outcome

Each of these distributions is defined by a neural network. The y distribution is defined by a disjoint pair of neural networks defining p(y|t=0,z) and p(y|t=1,z); this allows highly imbalanced treatment.

Parameters

config (dict) – A dict specifying feature_dim, latent_dim, hidden_dim, num_layers, and outcome_dist.
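The generative process above can be sketched in plain Python with hypothetical one-dimensional stand-ins for the neural networks. The point of the sketch is the twin outcome functions: y_mean_t0 and y_mean_t1 are separate, and t selects which one is used. All function names and coefficients are illustrative assumptions.

```python
import math
import random

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def x_given_z(z, rng):
    """x ~ p(x|z): a noisy observation of z."""
    return rng.gauss(z, 0.5)

def t_given_z(z, rng):
    """t ~ p(t|z): treatment whose probability depends on z."""
    return 1 if rng.random() < sigmoid(2.0 * z) else 0

# Twin outcome functions with separate parameters for each treatment arm,
# so a rarely treated arm does not share weights with the common one.
def y_mean_t0(z):
    return 0.5 * z

def y_mean_t1(z):
    return 0.5 * z + 1.0

def sample_model(rng):
    z = rng.gauss(0.0, 1.0)                       # z ~ p(z)
    x = x_given_z(z, rng)
    t = t_given_z(z, rng)
    loc = y_mean_t1(z) if t == 1 else y_mean_t0(z)
    y = rng.gauss(loc, 0.1)                       # y ~ p(y|t,z), a normal outcome
    return z, x, t, y

rng = random.Random(0)
samples = [sample_model(rng) for _ in range(5)]
```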

forward(x, t=None, y=None, size=None)
y_mean(x, t=None)
z_dist()
x_dist(z)
y_dist(t, z)
t_dist(z)
class Guide(config)

Bases: pyro.nn.module.PyroModule

Inference model for causal effect estimation with latent confounder z and binary treatment t:

t ~ q(t|x)      # treatment
y ~ q(y|t,x)    # outcome
z ~ q(z|y,t,x)  # latent confounder, an embedding

Each of these distributions is defined by a neural network. The y and z distributions are defined by disjoint pairs of neural networks defining q(-|t=0,...) and q(-|t=1,...); this allows highly imbalanced treatment.

Parameters

config (dict) – A dict specifying feature_dim, latent_dim, hidden_dim, num_layers, and outcome_dist.
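The guide's sampling order can be sketched similarly, again with hypothetical one-dimensional stand-ins for its networks. During training t and y are observed and passed in, while at prediction time only x is available, so t and y are drawn from q(t|x) and q(y|t,x) before z. The names q_t, q_y_loc, q_z_loc, and sample_guide are illustrative, not part of the Guide API.

```python
import math
import random

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def q_t(x):
    """Probability of t=1 under q(t|x)."""
    return sigmoid(x)

def q_y_loc(t, x):
    """Twin functions for q(y|t,x): one per treatment arm."""
    return (0.5 * x + 1.0) if t == 1 else 0.5 * x

def q_z_loc(y, t, x):
    """Twin functions for q(z|y,t,x)."""
    return 0.5 * (x + y - 1.0) if t == 1 else 0.5 * (x + y)

def sample_guide(x, rng, t=None, y=None):
    """Sample in the order t -> y -> z; observed t or y are used as given."""
    if t is None:
        t = 1 if rng.random() < q_t(x) else 0
    if y is None:
        y = rng.gauss(q_y_loc(t, x), 0.1)
    z = rng.gauss(q_z_loc(y, t, x), 0.1)
    return t, y, z

print(sample_guide(0.3, random.Random(0)))              # t, y sampled (prediction)
print(sample_guide(0.3, random.Random(0), t=1, y=2.0))  # t, y observed (training)
```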

forward(x, t=None, y=None, size=None)
t_dist(x)
y_dist(t, x)
z_dist(y, t, x)
class TraceCausalEffect_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.trace_elbo.Trace_ELBO

Loss function for training a CEVAE. From [1], the CEVAE objective (to maximize) is:

-loss = ELBO + log q(t|x) + log q(y|t,x)
loss(model, guide, *args, **kwargs)
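Numerically, the objective combines an ELBO estimate with the guide's log-probabilities of the observed t and y; the extra terms train q(t|x) and q(y|t,x) to predict treatment and outcome from x alone. The sketch below assumes the three quantities have already been computed for a single data point, and uses hypothetical numbers with a Bernoulli treatment and a normal outcome; cevae_loss and the log-prob helpers are illustrative, not the library's code.

```python
import math

def bernoulli_log_prob(value, prob):
    return math.log(prob) if value == 1 else math.log(1.0 - prob)

def normal_log_prob(value, loc, scale):
    return (-0.5 * ((value - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2.0 * math.pi))

def cevae_loss(elbo, log_q_t, log_q_y):
    """loss = -(ELBO + log q(t|x) + log q(y|t,x)), to be minimized."""
    return -(elbo + log_q_t + log_q_y)

# Hypothetical values for one data point.
elbo = -3.2
log_q_t = bernoulli_log_prob(1, prob=0.7)            # observed t = 1
log_q_y = normal_log_prob(1.5, loc=1.2, scale=0.5)   # observed y = 1.5
print(cevae_loss(elbo, log_q_t, log_q_y))
```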

Utilities

class FullyConnected(sizes, final_activation=None)

Bases: torch.nn.modules.container.Sequential

Fully connected multi-layer network with ELU activations.

append(layer)
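A pure-Python sketch of a fully connected network, for illustration only: Linear layers with ELU activations between them, and an optional final activation applied to the output. Whether FullyConnected omits the ELU after its last Linear layer is an assumption of this sketch; make_fully_connected and forward are hypothetical helpers with random weights.

```python
import math
import random

def elu(v, alpha=1.0):
    return v if v > 0 else alpha * (math.exp(v) - 1.0)

def make_fully_connected(sizes, seed=0):
    """Random (weight, bias) pairs for Linear layers sizes[0] -> ... -> sizes[-1]."""
    rng = random.Random(seed)
    layers = []
    for in_size, out_size in zip(sizes, sizes[1:]):
        weight = [[rng.gauss(0.0, 1.0 / math.sqrt(in_size)) for _ in range(in_size)]
                  for _ in range(out_size)]
        layers.append((weight, [0.0] * out_size))
    return layers

def forward(layers, x, final_activation=None):
    """Linear -> ELU between layers, no ELU after the last Linear layer,
    then an optional final activation."""
    for i, (weight, bias) in enumerate(layers):
        x = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]
        if i < len(layers) - 1:
            x = [elu(v) for v in x]
    if final_activation is not None:
        x = [final_activation(v) for v in x]
    return x

net = make_fully_connected([3, 4, 5])
out = forward(net, [0.1, -0.2, 0.3])
print(len(out))  # 5
```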
class DistributionNet

Bases: torch.nn.modules.module.Module

Base class for distribution nets.

static get_class(dtype)

Get a subclass by a prefix of its name, e.g.:

from pyro.contrib.cevae import BernoulliNet, DistributionNet
assert DistributionNet.get_class("bernoulli") is BernoulliNet
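One way such a name-based lookup can work is sketched below with hypothetical stand-in classes: it scans the direct subclasses and compares each class name, lowercased and with the trailing "Net" removed, against the requested string. The outcome_dist strings accepted by CEVAE ("bernoulli", "exponential", "laplace", "normal", "studentt") match the name prefixes of the DistributionNet subclasses documented in this section.

```python
class DistributionNet:
    """Hypothetical stand-in base class (the real one is a torch Module)."""

    @staticmethod
    def get_class(dtype):
        # "bernoulli" -> BernoulliNet, "studentt" -> StudentTNet, and so on.
        for cls in DistributionNet.__subclasses__():
            if cls.__name__.lower() == dtype.lower() + "net":
                return cls
        raise ValueError("dtype not supported: {}".format(dtype))

class BernoulliNet(DistributionNet):
    pass

class StudentTNet(DistributionNet):
    pass

assert DistributionNet.get_class("bernoulli") is BernoulliNet
```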
class BernoulliNet(sizes)

Bases: pyro.contrib.cevae.DistributionNet

FullyConnected network outputting a single logits value.

This is used to represent a conditional probability distribution of a single Bernoulli random variable conditioned on a sizes[0]-sized real value, for example:

import torch
from pyro.contrib.cevae import BernoulliNet

net = BernoulliNet([3, 4])
z = torch.randn(3)
logits, = net(z)
t = net.make_dist(logits).sample()
forward(x)
static make_dist(logits)
class ExponentialNet(sizes)

Bases: pyro.contrib.cevae.DistributionNet

FullyConnected network outputting a constrained rate.

This is used to represent a conditional probability distribution of a single Exponential random variable conditioned on a sizes[0]-sized real value, for example:

import torch
from pyro.contrib.cevae import ExponentialNet

net = ExponentialNet([3, 4])
x = torch.randn(3)
rate, = net(x)
y = net.make_dist(rate).sample()
forward(x)
static make_dist(rate)
class LaplaceNet(sizes)

Bases: pyro.contrib.cevae.DistributionNet

FullyConnected network outputting a constrained loc,scale pair.

This is used to represent a conditional probability distribution of a single Laplace random variable conditioned on a sizes[0]-sized real value, for example:

import torch
from pyro.contrib.cevae import LaplaceNet

net = LaplaceNet([3, 4])
x = torch.randn(3)
loc, scale = net(x)
y = net.make_dist(loc, scale).sample()
forward(x)
static make_dist(loc, scale)
class NormalNet(sizes)

Bases: pyro.contrib.cevae.DistributionNet

FullyConnected network outputting a constrained loc,scale pair.

This is used to represent a conditional probability distribution of a single Normal random variable conditioned on a sizes[0]-sized real value, for example:

import torch
from pyro.contrib.cevae import NormalNet

net = NormalNet([3, 4])
x = torch.randn(3)
loc, scale = net(x)
y = net.make_dist(loc, scale).sample()
forward(x)
static make_dist(loc, scale)
class StudentTNet(sizes)

Bases: pyro.contrib.cevae.DistributionNet

FullyConnected network outputting a constrained df,loc,scale triple, with shared df > 1.

This is used to represent a conditional probability distribution of a single Student’s t random variable conditioned on a sizes[0]-sized real value, for example:

import torch
from pyro.contrib.cevae import StudentTNet

net = StudentTNet([3, 4])
x = torch.randn(3)
df, loc, scale = net(x)
y = net.make_dist(df, loc, scale).sample()
forward(x)
static make_dist(df, loc, scale)
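The "constrained" outputs described for these nets can be sketched with a softplus transform, which is an assumption of this sketch rather than a documented detail: unconstrained values are mapped to a positive scale or rate, and a single shared parameter is mapped to df > 1. One plausible reason for requiring df > 1 is that the mean of a Student's t distribution exists only for df > 1. The helper names and the min_scale offset are hypothetical.

```python
import math

def softplus(v):
    """Numerically stable log(1 + exp(v)), which is positive for every real v."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

def constrain_scale(raw, min_scale=1e-3):
    """Map an unconstrained output to a positive scale (or rate)."""
    return softplus(raw) + min_scale

def constrain_df(raw_df):
    """Map one unconstrained parameter to a degrees-of-freedom value df > 1."""
    return 1.0 + softplus(raw_df)

# Per-input raw outputs, as a network might emit them (hypothetical values).
raw_scales = [-20.0, -1.0, 0.0, 3.0, 40.0]
scales = [constrain_scale(r) for r in raw_scales]

# A single df shared by all inputs: one learnable scalar, not a per-input output.
raw_df_param = 0.0
df = constrain_df(raw_df_param)  # 1 + log(2), about 1.69
```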
class DiagNormalNet(sizes)

Bases: torch.nn.modules.module.Module

FullyConnected network outputting a constrained loc,scale pair.

This is used to represent a conditional probability distribution of a sizes[-1]-sized diagonal Normal random variable conditioned on a sizes[0]-sized real value, for example:

import torch
import pyro.distributions as dist
from pyro.contrib.cevae import DiagNormalNet

net = DiagNormalNet([3, 4, 5])
z = torch.randn(3)
loc, scale = net(z)
x = dist.Normal(loc, scale).sample()

This is intended for the latent z distribution and the prewhitened x features, and conservatively clips loc and scale values.

forward(x)
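The conservative clipping can be sketched as clamping each loc coordinate and mapping each raw scale into a bounded positive range, then sampling each coordinate independently, which is what a diagonal Normal amounts to. The bounds loc_bound, min_scale, and max_scale, and the use of softplus, are illustrative assumptions, not the values or transforms used by DiagNormalNet.

```python
import math
import random

def softplus(v):
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

def diag_normal_params(raw_loc, raw_scale, loc_bound=100.0, min_scale=1e-3, max_scale=100.0):
    """Clip each loc and map each raw scale into (0, max_scale] (illustrative bounds)."""
    loc = [min(max(v, -loc_bound), loc_bound) for v in raw_loc]
    scale = [min(softplus(v) + min_scale, max_scale) for v in raw_scale]
    return loc, scale

def sample_diag_normal(loc, scale, rng):
    # Diagonal covariance: each coordinate is an independent univariate Normal.
    return [rng.gauss(m, s) for m, s in zip(loc, scale)]

loc, scale = diag_normal_params([0.5, -500.0, 2.0], [-3.0, 0.0, 1000.0])
x = sample_diag_normal(loc, scale, random.Random(0))
```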