Biological Sequence Models with MuE
Warning
Code in pyro.contrib.mue is under development.
This code makes no guarantee about maintaining backwards compatibility.
pyro.contrib.mue provides modeling tools for working with biological sequence data. In particular, it implements MuE distributions, which are used as a fully generative alternative to multiple sequence alignment-based preprocessing.
Reference: MuE models were described in Weinstein and Marks (2021), https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.
Example MuE Models
Example MuE observation models.
- class ProfileHMM(latent_seq_length, alphabet_length, prior_scale=1.0, indel_prior_bias=10.0, cuda=False, pin_memory=False)
Bases: torch.nn.modules.module.Module
Profile HMM.
This model consists of a constant distribution (a delta function) over the regressor sequence, plus a MuE observation distribution. The priors are all Normal distributions, and are pushed through a softmax function onto the simplex.
- Parameters
latent_seq_length (int) – Length of the latent regressor sequence M. Must be greater than or equal to 1.
alphabet_length (int) – Length of the sequence alphabet (e.g. 20 for amino acids).
prior_scale (float) – Standard deviation of the prior distribution.
indel_prior_bias (float) – Mean of the prior distribution over the log probability of an indel not occurring. Higher values lead to lower probability of indels.
cuda (bool) – Transfer data onto the GPU during training.
pin_memory (bool) – Pin memory for faster GPU transfer.
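As a rough illustration of how Normal priors are pushed through a softmax onto the simplex, the sketch below evaluates the softmax at the prior mean. It assumes, for illustration only, that indel_prior_bias acts as the logit of "no indel" against a logit of 0 for "indel"; the exact parameterization inside pyro.contrib.mue may differ. The softmax helper is hypothetical and uses only the Python standard library.

```python
import math


def softmax(logits):
    """Map a list of real-valued logits onto the probability simplex."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Assumed parameterization: logits (indel_prior_bias, 0) for (no indel, indel).
indel_prior_bias = 10.0
p_no_indel, p_indel = softmax([indel_prior_bias, 0.0])

# A Normal prior with mean 0 over 20 amino-acid logits gives a uniform
# distribution at the prior mean.
uniform = softmax([0.0] * 20)
```

Raising indel_prior_bias by one unit shrinks the prior-mean indel probability by roughly a factor of e, which matches the statement that higher values lead to a lower probability of indels.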
- fit_svi(dataset, epochs=2, batch_size=1, scheduler=None, jit=False)
Infer approximate posterior with stochastic variational inference.
This runs SVI. It is an approximate inference method useful for quickly iterating on probabilistic models.
- Parameters
dataset (Dataset) – The training dataset.
epochs (int) – Number of epochs of training.
batch_size (int) – Minibatch size (number of sequences).
scheduler (pyro.optim.MultiStepLR) – Optimization scheduler. (Default: Adam optimizer, 0.01 constant learning rate.)
jit (bool) – Whether to use a jit compiled ELBO.
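The scheduler argument takes a pyro.optim.MultiStepLR, Pyro's wrapper around PyTorch's torch.optim.lr_scheduler.MultiStepLR. Assuming it keeps PyTorch's behaviour (multiply the rate by gamma at each milestone) and that fit_svi steps it once per epoch (an assumption, not confirmed here), the learning rate evolves as in this standard-library sketch. multistep_lr is a hypothetical helper, not part of Pyro.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate after `epoch` scheduler steps under multi-step decay.

    The rate is multiplied by `gamma` once for every milestone that is
    less than or equal to `epoch`.
    """
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


# Hypothetical schedule: start at 0.1 and halve at epochs 2 and 4.
schedule = [multistep_lr(0.1, [2, 4], 0.5, e) for e in range(6)]
```

With a single milestone list left empty, the rate stays constant, which corresponds to the documented default of a constant learning rate.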
- class FactorMuE(data_length, alphabet_length, z_dim, batch_size=10, latent_seq_length=None, indel_factor_dependence=False, indel_prior_scale=1.0, indel_prior_bias=10.0, inverse_temp_prior=100.0, weights_prior_scale=1.0, offset_prior_scale=1.0, z_prior_distribution='Normal', ARD_prior=False, substitution_matrix=True, substitution_prior_scale=10.0, latent_alphabet_length=None, cuda=False, pin_memory=False, epsilon=1e-32)
Bases: torch.nn.modules.module.Module
This model consists of probabilistic PCA plus a MuE output distribution.
The priors are all Normal distributions, and where relevant pushed through a softmax onto the simplex.
- Parameters
data_length (int) – Length of the input sequence matrix, including zero padding at the end.
alphabet_length (int) – Length of the sequence alphabet (e.g. 20 for amino acids).
z_dim (int) – Number of dimensions of the z space.
batch_size (int) – Minibatch size.
latent_seq_length (int) – Length of the latent regressor sequence (M). Must be greater than or equal to 1. (Default: 1.1 x data_length.)
indel_factor_dependence (bool) – If True, indel probabilities depend on the latent variable z.
indel_prior_scale (float) – Standard deviation of the prior distribution on indel parameters.
indel_prior_bias (float) – Mean of the prior distribution over the log probability of an indel not occurring. Higher values lead to lower probability of indels.
inverse_temp_prior (float) – Mean of the prior distribution over the inverse temperature parameter.
weights_prior_scale (float) – Standard deviation of the prior distribution over the factors.
offset_prior_scale (float) – Standard deviation of the prior distribution over the offset (constant) in the pPCA model.
z_prior_distribution (str) – Prior distribution over the latent variable z. Either ‘Normal’ (pPCA model) or ‘Laplace’ (an ICA model).
ARD_prior (bool) – Use automatic relevance determination prior on factors.
substitution_matrix (bool) – Use a learnable substitution matrix rather than the identity matrix.
substitution_prior_scale (float) – Standard deviation of the prior distribution over substitution matrix parameters (when substitution_matrix is True).
latent_alphabet_length (int) – Length of the alphabet in the latent regressor sequence.
cuda (bool) – Transfer data onto the GPU during training.
pin_memory (bool) – Pin memory for faster GPU transfer.
epsilon (float) – A small value for numerical stability.
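To make the "probabilistic PCA plus softmax" description concrete, here is a simplified, single-position sketch: a latent vector z drawn from the 'Normal' prior is mapped linearly through factor weights plus an offset, and the result is pushed through a softmax. The function decode_position, the per-position layout of weights, and the int(1.1 * data_length) rule for the latent_seq_length default are illustrative assumptions; the real FactorMuE decodes all M positions jointly and adds indel, inverse temperature and substitution components.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def decode_position(z, weights, offset):
    """Hypothetical pPCA-style decoder for one latent sequence position.

    z:       latent vector of length z_dim
    weights: z_dim rows, each with one loading per latent alphabet letter
    offset:  constant logits, one per latent alphabet letter
    """
    logits = [offset[d] + sum(z[k] * weights[k][d] for k in range(len(z)))
              for d in range(len(offset))]
    return softmax(logits)


data_length = 50
latent_seq_length = int(1.1 * data_length)  # assumed form of the 1.1 x data_length default

rng = random.Random(0)
z_dim, alphabet_length = 2, 20
z = [rng.gauss(0.0, 1.0) for _ in range(z_dim)]  # z ~ Normal(0, 1)
weights = [[rng.gauss(0.0, 1.0) for _ in range(alphabet_length)] for _ in range(z_dim)]
offset = [0.0] * alphabet_length
probs = decode_position(z, weights, offset)
```

Setting z_prior_distribution='Laplace' would correspond to drawing z from a Laplace distribution instead, which is what turns the pPCA model into an ICA-style model.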
- fit_svi(dataset, epochs=2, anneal_length=1.0, batch_size=None, scheduler=None, jit=False)
Infer approximate posterior with stochastic variational inference.
This runs SVI. It is an approximate inference method useful for quickly iterating on probabilistic models.
- Parameters
dataset (Dataset) – The training dataset.
epochs (int) – Number of epochs of training.
anneal_length (float) – Number of epochs over which to linearly anneal the prior KL divergence weight from 0 to 1, for improved training.
batch_size (int) – Minibatch size (number of sequences).
scheduler (pyro.optim.MultiStepLR) – Optimization scheduler. (Default: Adam optimizer, 0.01 constant learning rate.)
jit (bool) – Whether to use a jit compiled ELBO.
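The anneal_length parameter describes a linear ramp of the prior KL weight from 0 to 1. A minimal sketch of such a ramp, measured in optimization steps, follows. The helper names anneal_weight and annealed_elbo, and the choice to treat anneal_length <= 0 as "no annealing", are assumptions for illustration, not Pyro's implementation.

```python
def anneal_weight(step, anneal_length, steps_per_epoch):
    """Linear KL-annealing weight: 0 at step 0, reaching 1 after
    `anneal_length` epochs and staying at 1 afterwards."""
    if anneal_length <= 0:
        return 1.0  # assumed: no annealing
    return min(1.0, step / (anneal_length * steps_per_epoch))


def annealed_elbo(log_likelihood, kl, weight):
    """ELBO with the prior KL divergence term scaled by the annealing weight."""
    return log_likelihood - weight * kl


# Example: 10 minibatches per epoch, annealing over one epoch.
weights = [anneal_weight(s, 1.0, 10) for s in (0, 5, 10, 20)]
```

Down-weighting the KL term early lets the model fit the data before the prior pulls the latent variables towards it, which is the usual motivation for this kind of annealing.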
State Arrangers for Parameterizing MuEs
- class Profile(M, epsilon=1e-32)
Bases: torch.nn.modules.module.Module
Profile HMM state arrangement. Parameterizes an HMM according to Equation S40 in [1] (with r_{M+1,j} = 1 and u_{M+1,j} = 0 for j in {0, 1, 2}). For further background on profile HMMs see [2].
References
[1] E. N. Weinstein, D. S. Marks (2021) “Generative probabilistic biological sequence models that account for mutational variability” https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf
[2] R. Durbin, S. R. Eddy, A. Krogh, and G. Mitchison (1998) “Biological sequence analysis: probabilistic models of proteins and nucleic acids” Cambridge University Press
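- Parameters
M (int) – Length of the latent regressor sequence.
epsilon (float) – A small value for numerical stability.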
- forward(precursor_seq_logits, insert_seq_logits, insert_logits, delete_logits, substitute_logits=None)
Assemble HMM parameters given profile parameters.
- Parameters
precursor_seq_logits (Tensor) – Regressor sequence log(x). Should have rightmost dimension (M, D) and be broadcastable to (batch_size, M, D), where D is the latent alphabet size. Should be normalized along the final axis, i.e. precursor_seq_logits.logsumexp(-1) = zeros.
insert_seq_logits (Tensor) – Insertion sequence log(c). Should have rightmost dimension (M+1, D) and be broadcastable to (batch_size, M+1, D). Should be normalized along the final axis.
insert_logits (Tensor) – Insertion probabilities log(r). Should have rightmost dimension (M, 3, 2) and be broadcastable to (batch_size, M, 3, 2). Should be normalized along the final axis.
delete_logits (Tensor) – Deletion probabilities log(u). Should have rightmost dimension (M, 3, 2) and be broadcastable to (batch_size, M, 3, 2). Should be normalized along the final axis.
substitute_logits (Tensor) – Substitution probabilities log(l). Should have rightmost dimension (D, B), where B is the alphabet size of the data, and be broadcastable to (batch_size, D, B). Must be normalized along the final axis.
- Returns
initial_logits, transition_logits, and observation_logits. These parameters can be used to directly initialize the MissingDataDiscreteHMM distribution.
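The inputs to forward are log probabilities normalized along the final axis, and the substitution logits have shape (D, B). The sketch below shows what that normalization means and one natural way to compose a (M, D) regressor distribution with a (D, B) substitution matrix in log space (a log-space matrix product). This composition is our reading of the shapes, offered as an assumption rather than as the Profile implementation; all helper names are hypothetical.

```python
import math


def logsumexp(xs):
    m = max(xs)
    if m == float("-inf"):
        return m  # every entry is log(0)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_normalize(row):
    """Shift a row of logits so that logsumexp(row) == 0 (probabilities sum to one)."""
    c = logsumexp(row)
    return [x - c for x in row]


def substitute(seq_logits, substitute_logits):
    """Log-space matrix product: out[m][b] = logsumexp_d(seq[m][d] + sub[d][b]).

    seq_logits:        M rows over the latent alphabet (size D)
    substitute_logits: D rows over the data alphabet (size B)
    """
    D, B = len(substitute_logits), len(substitute_logits[0])
    return [[logsumexp([row[d] + substitute_logits[d][b] for d in range(D)])
             for b in range(B)] for row in seq_logits]


# Hypothetical example with M = 1, D = 3, B = 2.
precursor = [log_normalize([0.0, 1.0, 2.0])]
sub = [log_normalize(r) for r in ([2.0, 0.0], [0.0, 2.0], [1.0, 1.0])]
obs = substitute(precursor, sub)
```

If both inputs are log-normalized, every output row is log-normalized too, and an identity substitution matrix (0 on the diagonal, -inf elsewhere) returns the regressor logits unchanged.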
Missing or Variable Length Data HMM
- class MissingDataDiscreteHMM(initial_logits, transition_logits, observation_logits, validate_args=None)
Bases: pyro.distributions.torch_distribution.TorchDistribution
HMM with discrete latent states and discrete observations, allowing for missing data or variable length sequences. Observations are assumed to be one-hot encoded; rows with all zeros indicate missing data.
Warning
Unlike pyro’s pyro.distributions.DiscreteHMM, which computes the probability of the first state as initial.T @ transition @ emission, this distribution uses the standard HMM convention, initial.T @ emission.
- Parameters
initial_logits (Tensor) – A logits tensor for an initial categorical distribution over latent states. Should have rightmost size state_dim and be broadcastable to (batch_size, state_dim).
transition_logits (Tensor) – A logits tensor for transition conditional distributions between latent states. Should have rightmost shape (state_dim, state_dim) (old, new), and be broadcastable to (batch_size, state_dim, state_dim).
observation_logits (Tensor) – A logits tensor for observation distributions from latent states. Should have rightmost shape (state_dim, categorical_size), where categorical_size is the dimension of the categorical output, and be broadcastable to (batch_size, state_dim, categorical_size).
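The Warning about the first-state convention is easiest to see on a toy two-state model. The sketch below computes the probability of the first observation under both conventions; the function names and the numbers are hypothetical and chosen so the two conventions disagree.

```python
def first_obs_prob_standard(initial, emission, y0):
    """Standard HMM convention: P(y_0) = sum_s initial[s] * emission[s][y0]."""
    return sum(initial[s] * emission[s][y0] for s in range(len(initial)))


def first_obs_prob_with_transition(initial, transition, emission, y0):
    """Convention with one transition before the first emission:
    P(y_0) = sum_s (initial @ transition)[s] * emission[s][y0]."""
    n = len(initial)
    first = [sum(initial[r] * transition[r][s] for r in range(n)) for s in range(n)]
    return sum(first[s] * emission[s][y0] for s in range(n))


initial = [1.0, 0.0]
transition = [[0.0, 1.0], [1.0, 0.0]]  # deterministic switch between the two states
emission = [[0.9, 0.1], [0.2, 0.8]]

p_standard = first_obs_prob_standard(initial, emission, 0)
p_shifted = first_obs_prob_with_transition(initial, transition, emission, 0)
```

Under the standard convention the first observation comes from state 0, while the shifted convention first moves to state 1, so initial_logits intended for one distribution cannot be reused unchanged in the other.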
- log_prob(value)
- Parameters
value (Tensor) – One-hot encoded observation. Must be real-valued (float) and broadcastable to (batch_size, num_steps, categorical_size), where categorical_size is the dimension of the categorical output. Missing data is represented by zeros, i.e. value[batch, step, :] == tensor([0, ..., 0]). Variable length observation sequences can be handled by padding the sequence with zeros at the end.
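A minimal standard-library sketch of the forward algorithm behind a log probability of this kind, using the standard convention and treating an all-zero row as an emission factor of 1 (log 0.0). It is an illustration of the technique, not the Pyro code; hmm_log_prob and the toy parameters are hypothetical.

```python
import math


def logsumexp(xs):
    m = max(xs)
    if m == float("-inf"):
        return m
    return m + math.log(sum(math.exp(x - m) for x in xs))


def hmm_log_prob(initial_logits, transition_logits, observation_logits, value):
    """Forward algorithm in log space with the standard convention (initial.T @ emission).

    `value` is a list of one-hot rows; an all-zero row (missing or padded)
    contributes a log-emission of 0.
    """
    n = len(initial_logits)

    def emit(s, row):
        # Pick out log p(y | s) for the observed symbol; empty sum for a zero row.
        return sum(o for v, o in zip(row, observation_logits[s]) if v != 0)

    alpha = [initial_logits[s] + emit(s, value[0]) for s in range(n)]
    for row in value[1:]:
        alpha = [logsumexp([alpha[r] + transition_logits[r][s] for r in range(n)])
                 + emit(s, row) for s in range(n)]
    return logsumexp(alpha)


init_l = [math.log(p) for p in [0.6, 0.4]]
trans_l = [[math.log(p) for p in row] for row in [[0.7, 0.3], [0.4, 0.6]]]
obs_l = [[math.log(p) for p in row] for row in [[0.9, 0.1], [0.2, 0.8]]]

seq = [[1.0, 0.0], [0.0, 1.0]]
padded = seq + [[0.0, 0.0], [0.0, 0.0]]
lp = hmm_log_prob(init_l, trans_l, obs_l, seq)
lp_padded = hmm_log_prob(init_l, trans_l, obs_l, padded)
```

Because each row of the transition matrix sums to one, trailing all-zero rows leave the result unchanged, which is why zero padding at the end handles variable-length sequences.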
- sample(sample_shape=torch.Size([]))
- Parameters
sample_shape (Size) – Sample shape. The last dimension must be num_steps, and the shape must be broadcastable to (batch_size, num_steps). batch_size must be an int, not a tuple.
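For intuition, here is a standard-library sketch of ancestral sampling from a discrete HMM under the standard convention described in the Warning: the first state is drawn from the initial distribution and emits directly, with no transition before step 0. hmm_sample and categorical are hypothetical helpers working in probability space, not Pyro's sample.

```python
import random


def categorical(rng, probs):
    """Draw an index from a discrete distribution given as probabilities."""
    u = rng.random()
    cum = 0.0
    for i, p in enumerate(probs):
        cum += p
        if u < cum:
            return i
    return len(probs) - 1  # guard against rounding in the cumulative sum


def hmm_sample(initial, transition, emission, num_steps, rng):
    """Sample a state path and one-hot observations of length num_steps."""
    states, obs = [], []
    s = categorical(rng, initial)
    for t in range(num_steps):
        if t > 0:
            s = categorical(rng, transition[s])
        y = categorical(rng, emission[s])
        states.append(s)
        obs.append([1.0 if b == y else 0.0 for b in range(len(emission[s]))])
    return states, obs


rng = random.Random(0)
# Deterministic toy model: start in state 0, alternate states, emit the state index.
states, obs = hmm_sample([1.0, 0.0], [[0.0, 1.0], [1.0, 0.0]],
                         [[1.0, 0.0], [0.0, 1.0]], num_steps=4, rng=rng)
```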
- filter(value)
Compute the marginal probability of the state variable at each step conditional on the previous observations.
- Parameters
value (Tensor) – One-hot encoded observation. Must be real-valued (float) and broadcastable to (batch_size, num_steps, categorical_size), where categorical_size is the dimension of the categorical output.
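A sketch of forward filtering in probability space with per-step normalization. Here "previous observations" is taken to include the observation at the current step, so each output is P(state_t | y_0, ..., y_t); that reading is our assumption about filter's exact conditioning. hmm_filter and the toy parameters are hypothetical.

```python
def hmm_filter(initial, transition, emission, value):
    """Return, for each step t, the distribution over states given y_0..y_t.

    All-zero rows (missing data) leave the predicted distribution unchanged.
    """
    n = len(initial)

    def likelihood(s, row):
        # Emission probability of the observed symbol; 1.0 for a missing row.
        return sum(p for v, p in zip(row, emission[s]) if v) if any(row) else 1.0

    out = []
    for t, row in enumerate(value):
        if t == 0:
            pred = list(initial)
        else:
            prev = out[-1]
            pred = [sum(prev[r] * transition[r][s] for r in range(n)) for s in range(n)]
        unnorm = [pred[s] * likelihood(s, row) for s in range(n)]
        z = sum(unnorm)
        out.append([u / z for u in unnorm])
    return out


initial = [0.6, 0.4]
transition = [[0.7, 0.3], [0.4, 0.6]]
emission = [[0.9, 0.1], [0.2, 0.8]]
filtered = hmm_filter(initial, transition, emission, [[1.0, 0.0], [0.0, 1.0]])
```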
- smooth(value)
Compute the posterior expected value of the state at each position (smoothing).
- Parameters
value (Tensor) – One-hot encoded observation. Must be real-valued (float) and broadcastable to (batch_size, num_steps, categorical_size), where categorical_size is the dimension of the categorical output.
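Smoothing conditions each state on the whole sequence. A standard way to compute this is the forward-backward algorithm, sketched below in probability space with per-step rescaling to avoid underflow; the posterior at step t is proportional to alpha_t(s) * beta_t(s). hmm_smooth and the toy parameters are hypothetical, and this is an illustration of the technique rather than Pyro's smooth.

```python
def hmm_smooth(initial, transition, emission, value):
    """Forward-backward smoothing: P(state_t | all observations) for each t."""
    n, T = len(initial), len(value)

    def lik(s, row):
        # Emission probability of the observed symbol; 1.0 for a missing row.
        return sum(p for v, p in zip(row, emission[s]) if v) if any(row) else 1.0

    # Forward pass, normalized at each step.
    alphas = []
    for t, row in enumerate(value):
        if t == 0:
            pred = list(initial)
        else:
            pred = [sum(alphas[-1][r] * transition[r][s] for r in range(n)) for s in range(n)]
        a = [pred[s] * lik(s, row) for s in range(n)]
        z = sum(a)
        alphas.append([x / z for x in a])

    # Backward pass, also normalized (constant factors cancel at the end).
    betas = [[1.0] * n for _ in range(T)]
    for t in range(T - 2, -1, -1):
        b = [sum(transition[s][r] * lik(r, value[t + 1]) * betas[t + 1][r] for r in range(n))
             for s in range(n)]
        z = sum(b)
        betas[t] = [x / z for x in b]

    # Combine and normalize per step.
    post = []
    for a, b in zip(alphas, betas):
        g = [x * y for x, y in zip(a, b)]
        z = sum(g)
        post.append([x / z for x in g])
    return post


initial = [0.6, 0.4]
transition = [[0.7, 0.3], [0.4, 0.6]]
emission = [[0.9, 0.1], [0.2, 0.8]]
smoothed = hmm_smooth(initial, transition, emission, [[1.0, 0.0], [0.0, 1.0]])
```

The posterior expected value of a one-hot state indicator is exactly this per-step state distribution.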
- sample_states(value)
Sample states with the forward filtering-backward sampling algorithm.
- Parameters
value (Tensor) – One-hot encoded observation. Must be real-valued (float) and broadcastable to (batch_size, num_steps, categorical_size), where categorical_size is the dimension of the categorical output.
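The forward filtering-backward sampling (FFBS) algorithm first runs the forward filter, then draws the last state from the final filtered distribution and each earlier state in proportion to filtered[t][r] * transition[r][next_state]. A standard-library sketch follows; hmm_sample_states is a hypothetical name, not Pyro's method.

```python
import random


def hmm_sample_states(initial, transition, emission, value, rng):
    """Draw one hidden state path from P(states | observations) via FFBS."""
    n = len(initial)

    def lik(s, row):
        # Emission probability of the observed symbol; 1.0 for a missing row.
        return sum(p for v, p in zip(row, emission[s]) if v) if any(row) else 1.0

    def draw(weights):
        # Sample an index proportionally to unnormalized weights.
        u = rng.random() * sum(weights)
        cum = 0.0
        for i, w in enumerate(weights):
            cum += w
            if u < cum:
                return i
        return len(weights) - 1

    # Forward filtering: alphas[t][s] = P(s_t = s | y_0..y_t).
    alphas = []
    for t, row in enumerate(value):
        if t == 0:
            pred = list(initial)
        else:
            pred = [sum(alphas[-1][r] * transition[r][s] for r in range(n)) for s in range(n)]
        a = [pred[s] * lik(s, row) for s in range(n)]
        z = sum(a)
        alphas.append([x / z for x in a])

    # Backward sampling: last state first, then each earlier state given the next.
    states = [draw(alphas[-1])]
    for t in range(len(value) - 2, -1, -1):
        nxt = states[0]
        states.insert(0, draw([alphas[t][r] * transition[r][nxt] for r in range(n)]))
    return states


rng = random.Random(0)
# With an identity emission matrix the observations pin down the states exactly.
path = hmm_sample_states([0.5, 0.5], [[0.5, 0.5], [0.5, 0.5]],
                         [[1.0, 0.0], [0.0, 1.0]],
                         [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], rng)
```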
- map_states(value)
Compute the maximum a posteriori (MAP) estimate of the state variable with the Viterbi algorithm.
- Parameters
value (Tensor) – One-hot encoded observation. Must be real-valued (float) and broadcastable to (batch_size, num_steps, categorical_size), where categorical_size is the dimension of the categorical output.
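A standard-library sketch of the Viterbi algorithm in log space, using the standard first-state convention and treating all-zero rows as carrying no emission term. It keeps, for each state, the best score of any path ending there plus a back-pointer, then traces back from the best final state. viterbi and the toy parameters are hypothetical.

```python
import math


def viterbi(initial_logits, transition_logits, observation_logits, value):
    """Most probable state path for one-hot observations `value`."""
    n = len(initial_logits)

    def emit(s, row):
        # log p(y | s) for a one-hot row; 0.0 (empty sum) for a missing row.
        return sum(o for v, o in zip(row, observation_logits[s]) if v)

    score = [initial_logits[s] + emit(s, value[0]) for s in range(n)]
    backptrs = []
    for row in value[1:]:
        ptr, new = [], []
        for s in range(n):
            best = max(range(n), key=lambda r: score[r] + transition_logits[r][s])
            ptr.append(best)
            new.append(score[best] + transition_logits[best][s] + emit(s, row))
        backptrs.append(ptr)
        score = new

    # Trace back from the highest-scoring final state.
    path = [max(range(n), key=lambda s: score[s])]
    for ptr in reversed(backptrs):
        path.insert(0, ptr[path[0]])
    return path


init_l = [math.log(p) for p in [0.6, 0.4]]
trans_l = [[math.log(p) for p in row] for row in [[0.7, 0.3], [0.4, 0.6]]]
obs_l = [[math.log(p) for p in row] for row in [[0.9, 0.1], [0.2, 0.8]]]
map_path = viterbi(init_l, trans_l, obs_l, [[1.0, 0.0], [0.0, 1.0]])
```

Enumerating the four possible paths by hand gives joint probabilities 0.0378, 0.1296, 0.0032 and 0.0384, so the path (0, 1) is the MAP estimate.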
Biosequence Dataset Loading
- class BiosequenceDataset(source, source_type='list', alphabet='amino-acid', max_length=None, include_stop=False, device=None)
Bases: Generic[torch.utils.data.dataset.T_co]
Load biological sequence data, either from a fasta file or a Python list.
- Parameters
source – Either the input fasta file path (str) or the input list of sequences (list of str).
source_type (str) – Type of input, either ‘list’ or ‘fasta’.
alphabet (str) – Alphabet to use. Alphabets ‘amino-acid’ and ‘dna’ are preset; any other input will be interpreted as the alphabet itself, i.e. you can use ‘ACGU’ for RNA.
max_length (int) – Total length of the one-hot representation of the sequences, including zero padding. Defaults to the maximum sequence length in the dataset.
include_stop (bool) – Append stop symbol to the end of each sequence and add the stop symbol to the alphabet.
device (torch.device) – Device on which data should be stored in memory.
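The dataset stores sequences as zero-padded one-hot arrays. The sketch below reproduces that representation with nested Python lists, including the optional stop symbol. The function one_hot_encode and the use of "*" as the stop symbol are illustrative assumptions, not the BiosequenceDataset code.

```python
def one_hot_encode(seqs, alphabet, max_length=None, include_stop=False, stop="*"):
    """Encode sequences as zero-padded one-hot rows.

    Returns (data, alphabet), where data[i][t] is a one-hot row for position t
    of sequence i, or all zeros past the end of the sequence.
    """
    if include_stop:
        alphabet = alphabet + stop  # assumed stop symbol
        seqs = [s + stop for s in seqs]
    if max_length is None:
        max_length = max(len(s) for s in seqs)
    index = {c: i for i, c in enumerate(alphabet)}
    data = []
    for s in seqs:
        if len(s) > max_length:
            raise ValueError("sequence longer than max_length")
        rows = [[0.0] * len(alphabet) for _ in range(max_length)]
        for t, c in enumerate(s):
            rows[t][index[c]] = 1.0
        data.append(rows)
    return data, alphabet


# RNA example with a custom alphabet, as in alphabet='ACGU'.
data, alph = one_hot_encode(["ACG", "A"], "ACGU")
data_stop, alph_stop = one_hot_encode(["ACG", "A"], "ACGU", include_stop=True)
```

The all-zero rows at the end of shorter sequences are exactly the padding that MissingDataDiscreteHMM interprets as missing data.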
- write(x, alphabet, file, truncate_stop=False, append=False, scores=None)
Write sequence samples to file.
- Parameters
x (Tensor) – One-hot encoded sequences, with size (data_size, seq_length, alphabet_length). May be padded with zeros for variable length sequences.
alphabet (array) – Alphabet.
file (str) – Output file, where sequences will be written in fasta format.
truncate_stop (bool) – If True, sequences will be truncated at the first stop symbol (i.e. the stop symbol and everything after will not be written). If False, the whole sequence will be written, including any internal stop symbols.
append (bool) – If True, sequences are appended to the end of the output file. If False, the file is first erased.
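A sketch of the decoding step behind write: convert each one-hot row back to a letter, skip zero padding, and optionally stop at the first stop symbol. It returns FASTA text rather than writing a file; the ">0", ">1" header format and the "*" stop symbol are hypothetical choices, and the real method's headers (for example how scores are reported) may differ.

```python
def to_fasta(x, alphabet, truncate_stop=False, stop="*"):
    """Decode zero-padded one-hot sequences and format them as FASTA text."""
    lines = []
    for i, seq in enumerate(x):
        chars = []
        for row in seq:
            if not any(row):  # zero padding: no letter at this position
                continue
            c = alphabet[row.index(max(row))]
            if truncate_stop and c == stop:
                break  # drop the stop symbol and everything after it
            chars.append(c)
        lines.append(f">{i}")  # hypothetical header
        lines.append("".join(chars))
    return "\n".join(lines) + "\n"


x = [[[1, 0, 0], [0, 0, 1], [0, 1, 0]],
     [[0, 1, 0], [0, 0, 0], [0, 0, 0]]]
full = to_fasta(x, "AC*")
truncated = to_fasta(x, "AC*", truncate_stop=True)
```

Writing the result with open(file, "a" if append else "w") would mirror the documented append behaviour.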