# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
"""
Example MuE observation models.
"""
import datetime
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import softplus
from torch.optim import Adam
from torch.utils.data import DataLoader
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.contrib.mue.missingdatahmm import MissingDataDiscreteHMM
from pyro.contrib.mue.statearrangers import Profile
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO
from pyro.optim import MultiStepLR


class ProfileHMM(nn.Module):
"""
    Profile HMM.

    This model consists of a constant distribution (a delta function) over the
    regressor sequence, plus a MuE observation distribution. The priors
    are all Normal distributions, and are pushed through a softmax function
    onto the simplex.

:param int latent_seq_length: Length of the latent regressor sequence M.
Must be greater than or equal to 1.
:param int alphabet_length: Length of the sequence alphabet (e.g. 20 for
amino acids).
:param float prior_scale: Standard deviation of the prior distribution.
:param float indel_prior_bias: Mean of the prior distribution over the
log probability of an indel not occurring. Higher values lead to lower
probability of indels.
:param bool cuda: Transfer data onto the GPU during training.
:param bool pin_memory: Pin memory for faster GPU transfer.
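
    Example (a minimal usage sketch; the one-hot toy data below stands in for
    a real sequence dataset such as
    :class:`~pyro.contrib.mue.dataloaders.BiosequenceDataset`, and the
    hyperparameter values are illustrative only)::

        import torch
        from torch.utils.data import TensorDataset

        # Two toy sequences, one-hot encoded over a 4-letter alphabet and
        # zero-padded to a common length of 6; the second tensor holds the
        # observed sequence lengths.
        seqs = torch.zeros([2, 6, 4])
        seqs[0, :3, 0] = 1.0
        seqs[1, :5, 1] = 1.0
        dataset = TensorDataset(seqs, torch.tensor([3.0, 5.0]))

        model = ProfileHMM(latent_seq_length=4, alphabet_length=4)
        losses = model.fit_svi(dataset, epochs=1, batch_size=2)
        train_lp, _, train_perplex, _ = model.evaluate(dataset)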
"""
def __init__(
self,
latent_seq_length,
alphabet_length,
prior_scale=1.0,
indel_prior_bias=10.0,
cuda=False,
pin_memory=False,
):
super().__init__()
assert isinstance(cuda, bool)
self.is_cuda = cuda
assert isinstance(pin_memory, bool)
self.pin_memory = pin_memory
assert isinstance(latent_seq_length, int) and latent_seq_length > 0
self.latent_seq_length = latent_seq_length
assert isinstance(alphabet_length, int) and alphabet_length > 0
self.alphabet_length = alphabet_length
self.precursor_seq_shape = (latent_seq_length, alphabet_length)
self.insert_seq_shape = (latent_seq_length + 1, alphabet_length)
self.indel_shape = (latent_seq_length, 3, 2)
assert isinstance(prior_scale, float)
self.prior_scale = prior_scale
assert isinstance(indel_prior_bias, float)
self.indel_prior = torch.tensor([indel_prior_bias, 0.0])
# Initialize state arranger.
self.statearrange = Profile(latent_seq_length)
def model(self, seq_data, local_scale):
# Latent sequence.
precursor_seq = pyro.sample(
"precursor_seq",
dist.Normal(
torch.zeros(self.precursor_seq_shape),
self.prior_scale * torch.ones(self.precursor_seq_shape),
).to_event(2),
)
precursor_seq_logits = precursor_seq - precursor_seq.logsumexp(-1, True)
insert_seq = pyro.sample(
"insert_seq",
dist.Normal(
torch.zeros(self.insert_seq_shape),
self.prior_scale * torch.ones(self.insert_seq_shape),
).to_event(2),
)
insert_seq_logits = insert_seq - insert_seq.logsumexp(-1, True)
# Indel probabilities.
insert = pyro.sample(
"insert",
dist.Normal(
self.indel_prior * torch.ones(self.indel_shape),
self.prior_scale * torch.ones(self.indel_shape),
).to_event(3),
)
insert_logits = insert - insert.logsumexp(-1, True)
delete = pyro.sample(
"delete",
dist.Normal(
self.indel_prior * torch.ones(self.indel_shape),
self.prior_scale * torch.ones(self.indel_shape),
).to_event(3),
)
delete_logits = delete - delete.logsumexp(-1, True)
# Construct HMM parameters.
initial_logits, transition_logits, observation_logits = self.statearrange(
precursor_seq_logits, insert_seq_logits, insert_logits, delete_logits
)
with pyro.plate("batch", seq_data.shape[0]):
with poutine.scale(scale=local_scale):
# Observations.
pyro.sample(
"obs_seq",
MissingDataDiscreteHMM(
initial_logits, transition_logits, observation_logits
),
obs=seq_data,
)
def guide(self, seq_data, local_scale):
# Sequence.
precursor_seq_q_mn = pyro.param(
"precursor_seq_q_mn", torch.zeros(self.precursor_seq_shape)
)
precursor_seq_q_sd = pyro.param(
"precursor_seq_q_sd", torch.zeros(self.precursor_seq_shape)
)
pyro.sample(
"precursor_seq",
dist.Normal(precursor_seq_q_mn, softplus(precursor_seq_q_sd)).to_event(2),
)
insert_seq_q_mn = pyro.param(
"insert_seq_q_mn", torch.zeros(self.insert_seq_shape)
)
insert_seq_q_sd = pyro.param(
"insert_seq_q_sd", torch.zeros(self.insert_seq_shape)
)
pyro.sample(
"insert_seq",
dist.Normal(insert_seq_q_mn, softplus(insert_seq_q_sd)).to_event(2),
)
# Indels.
insert_q_mn = pyro.param(
"insert_q_mn", torch.ones(self.indel_shape) * self.indel_prior
)
insert_q_sd = pyro.param("insert_q_sd", torch.zeros(self.indel_shape))
pyro.sample(
"insert",
dist.Normal(insert_q_mn, softplus(insert_q_sd)).to_event(3),
)
delete_q_mn = pyro.param(
"delete_q_mn", torch.ones(self.indel_shape) * self.indel_prior
)
delete_q_sd = pyro.param("delete_q_sd", torch.zeros(self.indel_shape))
pyro.sample(
"delete",
dist.Normal(delete_q_mn, softplus(delete_q_sd)).to_event(3),
)

    def fit_svi(
self,
dataset,
epochs=2,
batch_size=1,
scheduler=None,
jit=False,
):
"""
Infer approximate posterior with stochastic variational inference.
This runs :class:`~pyro.infer.svi.SVI`. It is an approximate inference
method useful for quickly iterating on probabilistic models.
:param ~torch.utils.data.Dataset dataset: The training dataset.
:param int epochs: Number of epochs of training.
:param int batch_size: Minibatch size (number of sequences).
:param pyro.optim.MultiStepLR scheduler: Optimization scheduler.
(Default: Adam optimizer, 0.01 constant learning rate.)
:param bool jit: Whether to use a jit compiled ELBO.
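
        Example (a sketch of passing a non-default optimization schedule, with
        ``model`` and ``dataset`` assumed to be set up as in the class
        docstring above; note that ``scheduler.step()`` is called after every
        SVI step here, so ``milestones`` count minibatch steps rather than
        epochs)::

            from torch.optim import Adam
            from pyro.optim import MultiStepLR

            scheduler = MultiStepLR(
                {
                    "optimizer": Adam,
                    "optim_args": {"lr": 0.05},
                    "milestones": [100, 200],
                    "gamma": 0.5,
                }
            )
            losses = model.fit_svi(dataset, epochs=10, batch_size=2,
                                   scheduler=scheduler)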
"""
# Setup.
if batch_size is not None:
self.batch_size = batch_size
if scheduler is None:
scheduler = MultiStepLR(
{
"optimizer": Adam,
"optim_args": {"lr": 0.01},
"milestones": [],
"gamma": 0.5,
}
)
if self.is_cuda:
device = torch.device("cuda")
else:
device = torch.device("cpu")
# Initialize guide.
self.guide(None, None)
dataload = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=self.pin_memory,
generator=torch.Generator(device=device),
)
# Setup stochastic variational inference.
if jit:
elbo = JitTrace_ELBO(ignore_jit_warnings=True)
else:
elbo = Trace_ELBO()
svi = SVI(self.model, self.guide, scheduler, loss=elbo)
# Run inference.
losses = []
t0 = datetime.datetime.now()
for epoch in range(epochs):
for seq_data, L_data in dataload:
if self.is_cuda:
seq_data = seq_data.cuda()
loss = svi.step(
seq_data, torch.tensor(len(dataset) / seq_data.shape[0])
)
losses.append(loss)
scheduler.step()
print(epoch, loss, " ", datetime.datetime.now() - t0)
return losses

    def evaluate(self, dataset_train, dataset_test=None, jit=False):
"""
        Evaluate performance (log probability and per residue perplexity) on
        train and test datasets.

        :param ~torch.utils.data.Dataset dataset_train: The training dataset.
        :param ~torch.utils.data.Dataset dataset_test: The testing dataset
            (optional).
        :param bool jit: Whether to use a jit compiled ELBO.
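
        Example (return value unpacking, with ``model``, ``dataset_train`` and
        ``dataset_test`` as placeholders; if ``dataset_test`` is omitted, the
        test entries are returned as ``None``)::

            train_lp, test_lp, train_perplex, test_perplex = model.evaluate(
                dataset_train, dataset_test
            )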
"""
dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False)
if dataset_test is not None:
dataload_test = DataLoader(dataset_test, batch_size=1, shuffle=False)
# Initialize guide.
self.guide(None, None)
if jit:
elbo = JitTrace_ELBO(ignore_jit_warnings=True)
else:
elbo = Trace_ELBO()
scheduler = MultiStepLR({"optimizer": Adam, "optim_args": {"lr": 0.01}})
# Setup stochastic variational inference.
svi = SVI(self.model, self.guide, scheduler, loss=elbo)
# Compute elbo and perplexity.
train_lp, train_perplex = self._evaluate_local_elbo(
svi, dataload_train, len(dataset_train)
)
if dataset_test is not None:
test_lp, test_perplex = self._evaluate_local_elbo(
svi, dataload_test, len(dataset_test)
)
return train_lp, test_lp, train_perplex, test_perplex
else:
return train_lp, None, train_perplex, None
def _local_variables(self, name, site):
"""Return per datapoint random variables in model."""
return name in ["obs_L", "obs_seq"]
def _evaluate_local_elbo(self, svi, dataload, data_size):
"""Evaluate elbo and average per residue perplexity."""
lp, perplex = 0.0, 0.0
with torch.no_grad():
for seq_data, L_data in dataload:
if self.is_cuda:
seq_data, L_data = seq_data.cuda(), L_data.cuda()
conditioned_model = poutine.condition(
self.model, data={"obs_seq": seq_data}
)
args = (seq_data, torch.tensor(1.0))
guide_tr = poutine.trace(self.guide).get_trace(*args)
model_tr = poutine.trace(
poutine.replay(conditioned_model, trace=guide_tr)
).get_trace(*args)
local_elbo = (
(
model_tr.log_prob_sum(self._local_variables)
- guide_tr.log_prob_sum(self._local_variables)
)
.cpu()
.numpy()
)
lp += local_elbo
perplex += -local_elbo / L_data[0].cpu().numpy()
perplex = np.exp(perplex / data_size)
return lp, perplex
class Encoder(nn.Module):
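    """Encode a one-hot sequence into the location and scale of the
    approximate posterior over the latent variable z (used by FactorMuE)."""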
def __init__(self, data_length, alphabet_length, z_dim):
super().__init__()
self.input_size = data_length * alphabet_length
self.f1_mn = nn.Linear(self.input_size, z_dim)
self.f1_sd = nn.Linear(self.input_size, z_dim)
def forward(self, data):
data = data.reshape(-1, self.input_size)
z_loc = self.f1_mn(data)
z_scale = softplus(self.f1_sd(data))
return z_loc, z_scale


class FactorMuE(nn.Module):
"""
    FactorMuE.

    This model consists of probabilistic PCA plus a MuE output distribution.
    The priors are all Normal distributions and, where relevant, are pushed
    through a softmax onto the simplex.

:param int data_length: Length of the input sequence matrix, including
zero padding at the end.
:param int alphabet_length: Length of the sequence alphabet (e.g. 20 for
amino acids).
:param int z_dim: Number of dimensions of the z space.
:param int batch_size: Minibatch size.
:param int latent_seq_length: Length of the latent regressor sequence (M).
Must be greater than or equal to 1. (Default: 1.1 x data_length.)
:param bool indel_factor_dependence: Indel probabilities depend on the
latent variable z.
:param float indel_prior_scale: Standard deviation of the prior
distribution on indel parameters.
:param float indel_prior_bias: Mean of the prior distribution over the
log probability of an indel not occurring. Higher values lead to lower
probability of indels.
:param float inverse_temp_prior: Mean of the prior distribution over the
inverse temperature parameter.
:param float weights_prior_scale: Standard deviation of the prior
distribution over the factors.
:param float offset_prior_scale: Standard deviation of the prior
distribution over the offset (constant) in the pPCA model.
:param str z_prior_distribution: Prior distribution over the latent
variable z. Either 'Normal' (pPCA model) or 'Laplace' (an ICA model).
:param bool ARD_prior: Use automatic relevance determination prior on
factors.
:param bool substitution_matrix: Use a learnable substitution matrix
rather than the identity matrix.
:param float substitution_prior_scale: Standard deviation of the prior
distribution over substitution matrix parameters (when
substitution_matrix is True).
:param int latent_alphabet_length: Length of the alphabet in the latent
regressor sequence.
:param bool cuda: Transfer data onto the GPU during training.
:param bool pin_memory: Pin memory for faster GPU transfer.
:param float epsilon: A small value for numerical stability.
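
    Example (a minimal usage sketch; as for :class:`ProfileHMM`, the one-hot
    toy data below stands in for a real sequence dataset such as
    :class:`~pyro.contrib.mue.dataloaders.BiosequenceDataset`)::

        import torch
        from torch.utils.data import TensorDataset

        seqs = torch.zeros([2, 6, 4])
        seqs[0, :3, 0] = 1.0
        seqs[1, :5, 1] = 1.0
        dataset = TensorDataset(seqs, torch.tensor([3.0, 5.0]))

        model = FactorMuE(data_length=6, alphabet_length=4, z_dim=2)
        losses = model.fit_svi(dataset, epochs=1, batch_size=2)
        z_loc, z_scale = model.embed(dataset)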
"""
def __init__(
self,
data_length,
alphabet_length,
z_dim,
batch_size=10,
latent_seq_length=None,
indel_factor_dependence=False,
indel_prior_scale=1.0,
indel_prior_bias=10.0,
inverse_temp_prior=100.0,
weights_prior_scale=1.0,
offset_prior_scale=1.0,
z_prior_distribution="Normal",
ARD_prior=False,
substitution_matrix=True,
substitution_prior_scale=10.0,
latent_alphabet_length=None,
cuda=False,
pin_memory=False,
epsilon=1e-32,
):
super().__init__()
assert isinstance(cuda, bool)
self.is_cuda = cuda
assert isinstance(pin_memory, bool)
self.pin_memory = pin_memory
# Constants.
assert isinstance(data_length, int) and data_length > 0
self.data_length = data_length
if latent_seq_length is None:
latent_seq_length = int(data_length * 1.1)
else:
assert isinstance(latent_seq_length, int) and latent_seq_length > 0
self.latent_seq_length = latent_seq_length
assert isinstance(alphabet_length, int) and alphabet_length > 0
self.alphabet_length = alphabet_length
assert isinstance(z_dim, int) and z_dim > 0
self.z_dim = z_dim
# Parameter shapes.
if (not substitution_matrix) or (latent_alphabet_length is None):
latent_alphabet_length = alphabet_length
self.latent_alphabet_length = latent_alphabet_length
self.indel_shape = (latent_seq_length, 3, 2)
self.total_factor_size = (
(2 * latent_seq_length + 1) * latent_alphabet_length
+ 2 * indel_factor_dependence * latent_seq_length * 3 * 2
)
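        # Layout of each flattened factor row of length total_factor_size
        # (unpacked in decoder() below): the first
        # latent_seq_length * latent_alphabet_length entries parameterize the
        # precursor sequence, the next
        # (latent_seq_length + 1) * latent_alphabet_length entries the insert
        # sequence, and, when indel_factor_dependence is True, the final
        # 2 * latent_seq_length * 3 * 2 entries the insert and delete
        # probabilities.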
# Architecture.
self.indel_factor_dependence = indel_factor_dependence
self.ARD_prior = ARD_prior
self.substitution_matrix = substitution_matrix
# Priors.
assert isinstance(indel_prior_scale, float)
self.indel_prior_scale = torch.tensor(indel_prior_scale)
assert isinstance(indel_prior_bias, float)
self.indel_prior = torch.tensor([indel_prior_bias, 0.0])
assert isinstance(inverse_temp_prior, float)
self.inverse_temp_prior = torch.tensor(inverse_temp_prior)
assert isinstance(weights_prior_scale, float)
self.weights_prior_scale = torch.tensor(weights_prior_scale)
assert isinstance(offset_prior_scale, float)
self.offset_prior_scale = torch.tensor(offset_prior_scale)
assert isinstance(epsilon, float)
self.epsilon = torch.tensor(epsilon)
assert isinstance(substitution_prior_scale, float)
self.substitution_prior_scale = torch.tensor(substitution_prior_scale)
self.z_prior_distribution = z_prior_distribution
# Batch control.
assert isinstance(batch_size, int)
self.batch_size = batch_size
# Initialize layers.
self.encoder = Encoder(data_length, alphabet_length, z_dim)
self.statearrange = Profile(latent_seq_length)
def decoder(self, z, W, B, inverse_temp):
# Project.
v = torch.mm(z, W) + B
out = dict()
if self.indel_factor_dependence:
# Extract insertion and deletion parameters.
ind0 = (2 * self.latent_seq_length + 1) * self.latent_alphabet_length
ind1 = ind0 + self.latent_seq_length * 3 * 2
ind2 = ind1 + self.latent_seq_length * 3 * 2
insert_v, delete_v = v[:, ind0:ind1], v[:, ind1:ind2]
insert_v = (
insert_v.reshape([-1, self.latent_seq_length, 3, 2]) + self.indel_prior
)
out["insert_logits"] = insert_v - insert_v.logsumexp(-1, True)
delete_v = (
delete_v.reshape([-1, self.latent_seq_length, 3, 2]) + self.indel_prior
)
out["delete_logits"] = delete_v - delete_v.logsumexp(-1, True)
# Extract precursor and insertion sequences.
ind0 = self.latent_seq_length * self.latent_alphabet_length
ind1 = ind0 + (self.latent_seq_length + 1) * self.latent_alphabet_length
precursor_seq_v, insert_seq_v = v[:, :ind0], v[:, ind0:ind1]
precursor_seq_v = (precursor_seq_v * softplus(inverse_temp)).reshape(
[-1, self.latent_seq_length, self.latent_alphabet_length]
)
out["precursor_seq_logits"] = precursor_seq_v - precursor_seq_v.logsumexp(
-1, True
)
insert_seq_v = (insert_seq_v * softplus(inverse_temp)).reshape(
[-1, self.latent_seq_length + 1, self.latent_alphabet_length]
)
out["insert_seq_logits"] = insert_seq_v - insert_seq_v.logsumexp(-1, True)
return out
def model(self, seq_data, local_scale, local_prior_scale):
# ARD prior.
if self.ARD_prior:
# Relevance factors
alpha = pyro.sample(
"alpha",
dist.Gamma(torch.ones(self.z_dim), torch.ones(self.z_dim)).to_event(1),
)
else:
alpha = torch.ones(self.z_dim)
# Factor and offset.
W = pyro.sample(
"W",
dist.Normal(
torch.zeros([self.z_dim, self.total_factor_size]),
torch.ones([self.z_dim, self.total_factor_size])
* self.weights_prior_scale
/ (alpha[:, None] + self.epsilon),
).to_event(2),
)
B = pyro.sample(
"B",
dist.Normal(
torch.zeros(self.total_factor_size),
torch.ones(self.total_factor_size) * self.offset_prior_scale,
).to_event(1),
)
# Indel probabilities.
if not self.indel_factor_dependence:
insert = pyro.sample(
"insert",
dist.Normal(
self.indel_prior * torch.ones(self.indel_shape),
self.indel_prior_scale * torch.ones(self.indel_shape),
).to_event(3),
)
insert_logits = insert - insert.logsumexp(-1, True)
delete = pyro.sample(
"delete",
dist.Normal(
self.indel_prior * torch.ones(self.indel_shape),
self.indel_prior_scale * torch.ones(self.indel_shape),
).to_event(3),
)
delete_logits = delete - delete.logsumexp(-1, True)
# Inverse temperature.
inverse_temp = pyro.sample(
"inverse_temp", dist.Normal(self.inverse_temp_prior, torch.tensor(1.0))
)
# Substitution matrix.
if self.substitution_matrix:
substitute = pyro.sample(
"substitute",
dist.Normal(
torch.zeros([self.latent_alphabet_length, self.alphabet_length]),
self.substitution_prior_scale
* torch.ones([self.latent_alphabet_length, self.alphabet_length]),
).to_event(2),
)
with pyro.plate("batch", seq_data.shape[0]):
with poutine.scale(scale=local_scale):
with poutine.scale(scale=local_prior_scale):
# Sample latent variable from prior.
if self.z_prior_distribution == "Normal":
z = pyro.sample(
"latent",
dist.Normal(
torch.zeros(self.z_dim), torch.ones(self.z_dim)
).to_event(1),
)
elif self.z_prior_distribution == "Laplace":
z = pyro.sample(
"latent",
dist.Laplace(
torch.zeros(self.z_dim), torch.ones(self.z_dim)
).to_event(1),
)
# Decode latent sequence.
decoded = self.decoder(z, W, B, inverse_temp)
if self.indel_factor_dependence:
insert_logits = decoded["insert_logits"]
delete_logits = decoded["delete_logits"]
# Construct HMM parameters.
if self.substitution_matrix:
(
initial_logits,
transition_logits,
observation_logits,
) = self.statearrange(
decoded["precursor_seq_logits"],
decoded["insert_seq_logits"],
insert_logits,
delete_logits,
substitute,
)
else:
(
initial_logits,
transition_logits,
observation_logits,
) = self.statearrange(
decoded["precursor_seq_logits"],
decoded["insert_seq_logits"],
insert_logits,
delete_logits,
)
# Draw samples.
pyro.sample(
"obs_seq",
MissingDataDiscreteHMM(
initial_logits, transition_logits, observation_logits
),
obs=seq_data,
)
def guide(self, seq_data, local_scale, local_prior_scale):
# Register encoder with pyro.
pyro.module("encoder", self.encoder)
# ARD weightings.
if self.ARD_prior:
alpha_conc = pyro.param("alpha_conc", torch.randn(self.z_dim))
alpha_rate = pyro.param("alpha_rate", torch.randn(self.z_dim))
pyro.sample(
"alpha",
dist.Gamma(softplus(alpha_conc), softplus(alpha_rate)).to_event(1),
)
# Factors.
W_q_mn = pyro.param("W_q_mn", torch.randn([self.z_dim, self.total_factor_size]))
W_q_sd = pyro.param("W_q_sd", torch.ones([self.z_dim, self.total_factor_size]))
pyro.sample("W", dist.Normal(W_q_mn, softplus(W_q_sd)).to_event(2))
B_q_mn = pyro.param("B_q_mn", torch.randn(self.total_factor_size))
B_q_sd = pyro.param("B_q_sd", torch.ones(self.total_factor_size))
pyro.sample("B", dist.Normal(B_q_mn, softplus(B_q_sd)).to_event(1))
# Indel probabilities.
if not self.indel_factor_dependence:
insert_q_mn = pyro.param(
"insert_q_mn", torch.ones(self.indel_shape) * self.indel_prior
)
insert_q_sd = pyro.param("insert_q_sd", torch.zeros(self.indel_shape))
pyro.sample(
"insert", dist.Normal(insert_q_mn, softplus(insert_q_sd)).to_event(3)
)
delete_q_mn = pyro.param(
"delete_q_mn", torch.ones(self.indel_shape) * self.indel_prior
)
delete_q_sd = pyro.param("delete_q_sd", torch.zeros(self.indel_shape))
pyro.sample(
"delete", dist.Normal(delete_q_mn, softplus(delete_q_sd)).to_event(3)
)
# Inverse temperature.
inverse_temp_q_mn = pyro.param("inverse_temp_q_mn", torch.tensor(0.0))
inverse_temp_q_sd = pyro.param("inverse_temp_q_sd", torch.tensor(0.0))
pyro.sample(
"inverse_temp", dist.Normal(inverse_temp_q_mn, softplus(inverse_temp_q_sd))
)
# Substitution matrix.
if self.substitution_matrix:
substitute_q_mn = pyro.param(
"substitute_q_mn",
torch.zeros([self.latent_alphabet_length, self.alphabet_length]),
)
substitute_q_sd = pyro.param(
"substitute_q_sd",
torch.zeros([self.latent_alphabet_length, self.alphabet_length]),
)
pyro.sample(
"substitute",
dist.Normal(substitute_q_mn, softplus(substitute_q_sd)).to_event(2),
)
# Per datapoint local latent variables.
with pyro.plate("batch", seq_data.shape[0]):
# Encode sequences.
z_loc, z_scale = self.encoder(seq_data)
# Scale log likelihood to account for mini-batching.
with poutine.scale(scale=local_scale * local_prior_scale):
# Sample.
if self.z_prior_distribution == "Normal":
pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
elif self.z_prior_distribution == "Laplace":
pyro.sample("latent", dist.Laplace(z_loc, z_scale).to_event(1))

    def fit_svi(
self,
dataset,
epochs=2,
anneal_length=1.0,
batch_size=None,
scheduler=None,
jit=False,
):
"""
Infer approximate posterior with stochastic variational inference.
This runs :class:`~pyro.infer.svi.SVI`. It is an approximate inference
method useful for quickly iterating on probabilistic models.
:param ~torch.utils.data.Dataset dataset: The training dataset.
:param int epochs: Number of epochs of training.
:param float anneal_length: Number of epochs over which to linearly
anneal the prior KL divergence weight from 0 to 1, for improved
training.
        :param int batch_size: Minibatch size (number of sequences). (Defaults
            to the batch_size of the model object.)
:param pyro.optim.MultiStepLR scheduler: Optimization scheduler.
(Default: Adam optimizer, 0.01 constant learning rate.)
:param bool jit: Whether to use a jit compiled ELBO.
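
        Example (a sketch, with ``model`` and ``dataset`` set up as in the
        class docstring above; the weight on the prior KL terms is annealed
        linearly from 0 to 1 over the first two epochs of training)::

            losses = model.fit_svi(dataset, epochs=10, anneal_length=2.0,
                                   batch_size=2)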
"""
# Setup.
if batch_size is not None:
self.batch_size = batch_size
if scheduler is None:
scheduler = MultiStepLR(
{
"optimizer": Adam,
"optim_args": {"lr": 0.01},
"milestones": [],
"gamma": 0.5,
}
)
if self.is_cuda:
device = torch.device("cuda")
else:
device = torch.device("cpu")
dataload = DataLoader(
dataset,
            batch_size=self.batch_size,
shuffle=True,
pin_memory=self.pin_memory,
generator=torch.Generator(device=device),
)
# Initialize guide.
for seq_data, L_data in dataload:
if self.is_cuda:
seq_data = seq_data.cuda()
self.guide(seq_data, torch.tensor(1.0), torch.tensor(1.0))
break
# Setup stochastic variational inference.
if jit:
elbo = JitTrace_ELBO(ignore_jit_warnings=True)
else:
elbo = Trace_ELBO()
svi = SVI(self.model, self.guide, scheduler, loss=elbo)
# Run inference.
losses = []
step_i = 1
t0 = datetime.datetime.now()
for epoch in range(epochs):
for seq_data, L_data in dataload:
if self.is_cuda:
seq_data = seq_data.cuda()
loss = svi.step(
seq_data,
torch.tensor(len(dataset) / seq_data.shape[0]),
                    self._beta_anneal(
                        step_i, self.batch_size, len(dataset), anneal_length
                    ),
)
losses.append(loss)
scheduler.step()
step_i += 1
print(epoch, loss, " ", datetime.datetime.now() - t0)
return losses
def _beta_anneal(self, step, batch_size, data_size, anneal_length):
"""Annealing schedule for prior KL term (beta annealing)."""
if np.allclose(anneal_length, 0.0):
return torch.tensor(1.0)
anneal_frac = step * batch_size / (anneal_length * data_size)
return torch.tensor(min([anneal_frac, 1.0]))

    def evaluate(self, dataset_train, dataset_test=None, jit=False):
"""
        Evaluate performance (log probability and per residue perplexity) on
        train and test datasets.

        :param ~torch.utils.data.Dataset dataset_train: The training dataset.
        :param ~torch.utils.data.Dataset dataset_test: The testing dataset
            (optional).
        :param bool jit: Whether to use a jit compiled ELBO.
"""
dataload_train = DataLoader(dataset_train, batch_size=1, shuffle=False)
if dataset_test is not None:
dataload_test = DataLoader(dataset_test, batch_size=1, shuffle=False)
# Initialize guide.
for seq_data, L_data in dataload_train:
if self.is_cuda:
seq_data = seq_data.cuda()
self.guide(seq_data, torch.tensor(1.0), torch.tensor(1.0))
break
if jit:
elbo = JitTrace_ELBO(ignore_jit_warnings=True)
else:
elbo = Trace_ELBO()
scheduler = MultiStepLR({"optimizer": Adam, "optim_args": {"lr": 0.01}})
# Setup stochastic variational inference.
svi = SVI(self.model, self.guide, scheduler, loss=elbo)
# Compute elbo and perplexity.
train_lp, train_perplex = self._evaluate_local_elbo(
svi, dataload_train, len(dataset_train)
)
if dataset_test is not None:
test_lp, test_perplex = self._evaluate_local_elbo(
svi, dataload_test, len(dataset_test)
)
return train_lp, test_lp, train_perplex, test_perplex
else:
return train_lp, None, train_perplex, None
def _local_variables(self, name, site):
"""Return per datapoint random variables in model."""
return name in ["latent", "obs_L", "obs_seq"]
def _evaluate_local_elbo(self, svi, dataload, data_size):
"""Evaluate elbo and average per residue perplexity."""
lp, perplex = 0.0, 0.0
with torch.no_grad():
for seq_data, L_data in dataload:
if self.is_cuda:
seq_data, L_data = seq_data.cuda(), L_data.cuda()
conditioned_model = poutine.condition(
self.model, data={"obs_seq": seq_data}
)
args = (seq_data, torch.tensor(1.0), torch.tensor(1.0))
guide_tr = poutine.trace(self.guide).get_trace(*args)
model_tr = poutine.trace(
poutine.replay(conditioned_model, trace=guide_tr)
).get_trace(*args)
local_elbo = (
(
model_tr.log_prob_sum(self._local_variables)
- guide_tr.log_prob_sum(self._local_variables)
)
.cpu()
.numpy()
)
lp += local_elbo
perplex += -local_elbo / L_data[0].cpu().numpy()
perplex = np.exp(perplex / data_size)
return lp, perplex

    def embed(self, dataset, batch_size=None):
"""
Get the latent space embedding (mean posterior value of z).
:param ~torch.utils.data.Dataset dataset: The dataset to embed.
:param int batch_size: Minibatch size (number of sequences). (Defaults
to batch_size of the model object.)
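
        Example (a sketch of visualizing a two-dimensional embedding, assuming
        ``z_dim == 2`` and that matplotlib is installed; ``model`` and
        ``dataset`` are as in the class docstring above)::

            import matplotlib.pyplot as plt

            z_loc, z_scale = model.embed(dataset)
            plt.scatter(z_loc[:, 0], z_loc[:, 1])
            plt.show()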
"""
if batch_size is None:
batch_size = self.batch_size
dataload = DataLoader(dataset, batch_size=batch_size, shuffle=False)
with torch.no_grad():
z_locs, z_scales = [], []
for seq_data, L_data in dataload:
if self.is_cuda:
seq_data = seq_data.cuda()
z_loc, z_scale = self.encoder(seq_data)
z_locs.append(z_loc.cpu())
z_scales.append(z_scale.cpu())
return torch.cat(z_locs), torch.cat(z_scales)
def _reconstruct_regressor_seq(self, data, ind, param):
"Reconstruct the latent regressor sequence given data."
with torch.no_grad():
# Encode seq.
z_loc = self.encoder(data[ind][0])[0]
# Reconstruct
decoded = self.decoder(
z_loc, param("W_q_mn"), param("B_q_mn"), param("inverse_temp_q_mn")
)
return torch.exp(decoded["precursor_seq_logits"])