# Biological Sequence Models with MuE

Warning

Code in pyro.contrib.mue is under development. This code makes no guarantee about maintaining backwards compatibility.

pyro.contrib.mue provides modeling tools for working with biological sequence data. In particular, it implements MuE distributions, which serve as a fully generative alternative to preprocessing based on multiple sequence alignment.

Reference: MuE models were described in Weinstein and Marks (2021), https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.

## Example MuE Models

Example MuE observation models.

class ProfileHMM(latent_seq_length, alphabet_length, prior_scale=1.0, indel_prior_bias=10.0, cuda=False, pin_memory=False)

Bases: torch.nn.modules.module.Module

Profile HMM.

This model consists of a constant distribution (a delta function) over the regressor sequence, plus a MuE observation distribution. The priors are all Normal distributions, and are pushed through a softmax function onto the simplex.
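As a minimal illustration of the last sentence (not the library's own code), the sketch below draws unconstrained values from a Normal prior and maps each row onto the probability simplex with a softmax. It uses NumPy only; the sizes M and D and all variable names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

M, D = 4, 20        # illustrative: latent sequence length and alphabet size
prior_scale = 1.0   # standard deviation of the Normal prior


def log_softmax(x, axis=-1):
    """Numerically stable log-softmax along the given axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


# Unconstrained draws from a Normal(0, prior_scale) prior.
raw = rng.normal(loc=0.0, scale=prior_scale, size=(M, D))

# Each row of `probs` is a point on the simplex: positive and summing to one.
logits = log_softmax(raw)
probs = np.exp(logits)
```

A larger prior_scale spreads the raw values further apart, so the resulting rows tend to concentrate more mass on a few letters.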

Parameters:

- latent_seq_length (int) – Length of the latent regressor sequence M. Must be greater than or equal to 1.
- alphabet_length (int) – Length of the sequence alphabet (e.g. 20 for amino acids).
- prior_scale (float) – Standard deviation of the prior distribution.
- indel_prior_bias (float) – Mean of the prior distribution over the log probability of an indel not occurring. Higher values lead to lower probability of indels.
- cuda (bool) – Transfer data onto the GPU during training.
- pin_memory (bool) – Pin memory for faster GPU transfer.

fit_svi(dataset, epochs=2, batch_size=1, scheduler=None, jit=False)

Infer approximate posterior with stochastic variational inference.

This runs SVI. It is an approximate inference method useful for quickly iterating on probabilistic models.

Parameters:

- dataset (Dataset) – The training dataset.
- epochs (int) – Number of epochs of training.
- batch_size (int) – Minibatch size (number of sequences).
- scheduler (pyro.optim.MultiStepLR) – Optimization scheduler. (Default: Adam optimizer, 0.01 constant learning rate.)
- jit (bool) – Whether to use a jit compiled ELBO.

evaluate(dataset_train, dataset_test=None, jit=False)

Evaluate performance (log probability and per residue perplexity) on train and test datasets.

Parameters:

- dataset_train (Dataset) – The training dataset.
- dataset_test (Dataset) – The testing dataset (optional).
- jit (bool) – Whether to use a jit compiled ELBO.
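The exact averaging behind the per residue perplexity is not spelled out here. The sketch below assumes one common definition: the exponential of the average negative log probability per residue, averaged over sequences. The function name per_residue_perplexity is hypothetical and the code is not taken from the library.

```python
import math


def per_residue_perplexity(log_probs, lengths):
    """Exponential of the mean per-residue negative log probability.

    log_probs: per-sequence log probabilities (natural log).
    lengths:   number of residues in each sequence.
    """
    per_residue_nll = [-lp / n for lp, n in zip(log_probs, lengths)]
    return math.exp(sum(per_residue_nll) / len(per_residue_nll))


# A model that assigns probability 1/20 to every residue has perplexity 20,
# i.e. it is as uncertain as a uniform guess over 20 amino acids.
lengths = [10, 25]
log_probs = [n * math.log(1 / 20) for n in lengths]
perplexity = per_residue_perplexity(log_probs, lengths)
```

Under this definition, lower values are better, and a value close to the alphabet size indicates a model that is no better than uniform guessing.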
class FactorMuE(data_length, alphabet_length, z_dim, batch_size=10, latent_seq_length=None, indel_factor_dependence=False, indel_prior_scale=1.0, indel_prior_bias=10.0, inverse_temp_prior=100.0, weights_prior_scale=1.0, offset_prior_scale=1.0, z_prior_distribution='Normal', ARD_prior=False, substitution_matrix=True, substitution_prior_scale=10.0, latent_alphabet_length=None, cuda=False, pin_memory=False, epsilon=1e-32)

Bases: torch.nn.modules.module.Module

This model consists of probabilistic PCA plus a MuE output distribution.

The priors are all Normal distributions, and where relevant pushed through a softmax onto the simplex.

Parameters:

- data_length (int) – Length of the input sequence matrix, including zero padding at the end.
- alphabet_length (int) – Length of the sequence alphabet (e.g. 20 for amino acids).
- z_dim (int) – Number of dimensions of the z space.
- batch_size (int) – Minibatch size.
- latent_seq_length (int) – Length of the latent regressor sequence (M). Must be greater than or equal to 1. (Default: 1.1 x data_length.)
- indel_factor_dependence (bool) – Indel probabilities depend on the latent variable z.
- indel_prior_scale (float) – Standard deviation of the prior distribution on indel parameters.
- indel_prior_bias (float) – Mean of the prior distribution over the log probability of an indel not occurring. Higher values lead to lower probability of indels.
- inverse_temp_prior (float) – Mean of the prior distribution over the inverse temperature parameter.
- weights_prior_scale (float) – Standard deviation of the prior distribution over the factors.
- offset_prior_scale (float) – Standard deviation of the prior distribution over the offset (constant) in the pPCA model.
- z_prior_distribution (str) – Prior distribution over the latent variable z. Either ‘Normal’ (pPCA model) or ‘Laplace’ (an ICA model).
- ARD_prior (bool) – Use automatic relevance determination prior on factors.
- substitution_matrix (bool) – Use a learnable substitution matrix rather than the identity matrix.
- substitution_prior_scale (float) – Standard deviation of the prior distribution over substitution matrix parameters (when substitution_matrix is True).
- latent_alphabet_length (int) – Length of the alphabet in the latent regressor sequence.
- cuda (bool) – Transfer data onto the GPU during training.
- pin_memory (bool) – Pin memory for faster GPU transfer.
- epsilon (float) – A small value for numerical stability.

fit_svi(dataset, epochs=2, anneal_length=1.0, batch_size=None, scheduler=None, jit=False)

Infer approximate posterior with stochastic variational inference.

This runs SVI. It is an approximate inference method useful for quickly iterating on probabilistic models.

Parameters:

- dataset (Dataset) – The training dataset.
- epochs (int) – Number of epochs of training.
- anneal_length (float) – Number of epochs over which to linearly anneal the prior KL divergence weight from 0 to 1, for improved training.
- batch_size (int) – Minibatch size (number of sequences).
- scheduler (pyro.optim.MultiStepLR) – Optimization scheduler. (Default: Adam optimizer, 0.01 constant learning rate.)
- jit (bool) – Whether to use a jit compiled ELBO.
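The linear annealing described for anneal_length can be pictured with the following sketch. It is an assumption about the shape of the schedule only: the function name kl_weight and the use of fractional epochs as the progress measure are hypothetical and not taken from the library.

```python
def kl_weight(epoch_progress, anneal_length=1.0):
    """Linear annealing weight for the prior KL term.

    epoch_progress: training progress in (possibly fractional) epochs.
    anneal_length:  number of epochs over which the weight rises from 0 to 1.
    """
    if anneal_length <= 0:
        # No annealing requested: use the full KL weight from the start.
        return 1.0
    # Clamp to [0, 1] so the weight stays at 1 once annealing is finished.
    return min(1.0, max(0.0, epoch_progress / anneal_length))


# With anneal_length=2.0 the weight reaches 1 after two epochs and stays there.
weights = [kl_weight(t, anneal_length=2.0) for t in (0.0, 0.5, 1.0, 2.0, 3.0)]
```

Starting with a small KL weight lets the model fit the data before the prior pulls the latent variables toward it, which is the usual motivation for this kind of schedule.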
evaluate(dataset_train, dataset_test=None, jit=False)

Evaluate performance (log probability and per residue perplexity) on train and test datasets.

Parameters:

- dataset_train (Dataset) – The training dataset.
- dataset_test (Dataset) – The testing dataset (optional).
- jit (bool) – Whether to use a jit compiled ELBO.

embed(dataset, batch_size=None)

Get the latent space embedding (mean posterior value of z).

Parameters:

- dataset (Dataset) – The dataset to embed.
- batch_size (int) – Minibatch size (number of sequences). (Defaults to batch_size of the model object.)

## State Arrangers for Parameterizing MuEs

class Profile(M, epsilon=1e-32)

Bases: torch.nn.modules.module.Module

Profile HMM state arrangement. Parameterizes an HMM according to Equation S40 in [1] (with r_{M+1,j} = 1 and u_{M+1,j} = 0 for j in {0, 1, 2}). For further background on profile HMMs see [2].

References

[1] E. N. Weinstein, D. S. Marks (2021) “Generative probabilistic biological sequence models that account for mutational variability” https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf

[2] R. Durbin, S. R. Eddy, A. Krogh, and G. Mitchison (1998) “Biological sequence analysis: probabilistic models of proteins and nucleic acids” Cambridge University Press

Parameters:

- M (int) – Length of regressor sequence.
- epsilon (float) – A small value for numerical stability.

forward(precursor_seq_logits, insert_seq_logits, insert_logits, delete_logits, substitute_logits=None)

Assemble HMM parameters given profile parameters.

Parameters:

- precursor_seq_logits (Tensor) – Regressor sequence log(x). Should have rightmost dimension (M, D) and be broadcastable to (batch_size, M, D), where D is the latent alphabet size. Should be normalized to one along the final axis, i.e. precursor_seq_logits.logsumexp(-1) = zeros.
- insert_seq_logits (Tensor) – Insertion sequence log(c). Should have rightmost dimension (M+1, D) and be broadcastable to (batch_size, M+1, D). Should be normalized along the final axis.
- insert_logits (Tensor) – Insertion probabilities log(r). Should have rightmost dimension (M, 3, 2) and be broadcastable to (batch_size, M, 3, 2). Should be normalized along the final axis.
- delete_logits (Tensor) – Deletion probabilities log(u). Should have rightmost dimension (M, 3, 2) and be broadcastable to (batch_size, M, 3, 2). Should be normalized along the final axis.
- substitute_logits (Tensor) – Substitution probabilities log(l). Should have rightmost dimension (D, B), where B is the alphabet size of the data, and be broadcastable to (batch_size, D, B). Must be normalized along the final axis.

Returns: initial_logits, transition_logits, and observation_logits. These parameters can be used to directly initialize the MissingDataDiscreteHMM distribution.
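To make the shape and normalization requirements concrete, the sketch below builds arrays with the rightmost dimensions listed above and checks that the logsumexp over the final axis is zero. It uses NumPy only to stay dependency-light; in practice these inputs would be torch tensors. The helper names and the sizes M, D and B are illustrative assumptions, and the code does not call the library.

```python
import numpy as np

rng = np.random.default_rng(0)


def normalized_logits(shape):
    """Random logits normalized so that logsumexp over the last axis is zero."""
    raw = rng.normal(size=shape)
    shifted = raw - raw.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def logsumexp_last(x):
    """Log of the summed exponentials along the final axis."""
    m = x.max(axis=-1, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=-1, keepdims=True))).squeeze(-1)


M, D, B = 5, 4, 4  # illustrative sizes

inputs = {
    "precursor_seq_logits": normalized_logits((M, D)),      # log(x)
    "insert_seq_logits": normalized_logits((M + 1, D)),     # log(c)
    "insert_logits": normalized_logits((M, 3, 2)),          # log(r)
    "delete_logits": normalized_logits((M, 3, 2)),          # log(u)
    "substitute_logits": normalized_logits((D, B)),         # log(l)
}

# Every input is a log-probability table along its final axis.
all_normalized = all(
    np.allclose(logsumexp_last(t), 0.0) for t in inputs.values()
)
```

Applying a log-softmax along the last axis, as normalized_logits does, is a simple way to satisfy the normalization condition for unconstrained parameters.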
mg2k(m, g, M)

Convert from (m, g) indexing to k indexing.

## Missing or Variable Length Data HMM

class MissingDataDiscreteHMM(initial_logits, transition_logits, observation_logits, validate_args=None)

Bases: pyro.distributions.torch_distribution.TorchDistribution

HMM with discrete latent states and discrete observations, allowing for missing data or variable length sequences. Observations are assumed to be one hot encoded; rows with all zeros indicate missing data.

Warning

Unlike pyro.distributions.DiscreteHMM, which computes the probability of the first state as initial.T @ transition @ emission, this distribution uses the standard HMM convention, initial.T @ emission.
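The difference between the two conventions is easiest to see on a small example. The sketch below uses NumPy and made-up probability tables (all values are illustrative) to compute the distribution over the first observed symbol under each convention.

```python
import numpy as np

initial = np.array([0.9, 0.1])
transition = np.array([[0.2, 0.8],
                       [0.6, 0.4]])   # rows: old state, columns: new state
emission = np.array([[0.7, 0.3],
                     [0.1, 0.9]])     # rows: state, columns: observed symbol

# Standard HMM convention (used by this distribution): the first observation
# is emitted directly from the initial state distribution.
first_obs_standard = initial @ emission

# Convention in which a transition is applied before the first emission.
first_obs_shifted = initial @ transition @ emission
```

Here first_obs_standard is [0.64, 0.36] while first_obs_shifted is [0.244, 0.756], so parameters fitted under one convention cannot be passed unchanged to a distribution that uses the other.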

Parameters:

- initial_logits (Tensor) – A logits tensor for an initial categorical distribution over latent states. Should have rightmost size state_dim and be broadcastable to (batch_size, state_dim).
- transition_logits (Tensor) – A logits tensor for transition conditional distributions between latent states. Should have rightmost shape (state_dim, state_dim) (old, new), and be broadcastable to (batch_size, state_dim, state_dim).
- observation_logits (Tensor) – A logits tensor for observation distributions from latent states. Should have rightmost shape (state_dim, categorical_size), where categorical_size is the dimension of the categorical output, and be broadcastable to (batch_size, state_dim, categorical_size).

log_prob(value)

Parameters:

- value (Tensor) – One-hot encoded observation. Must be real-valued (float) and broadcastable to (batch_size, num_steps, categorical_size) where categorical_size is the dimension of the categorical output. Missing data is represented by zeros, i.e. value[batch, step, :] == tensor([0, ..., 0]). Variable length observation sequences can be handled by padding the sequence with zeros at the end.
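The following sketch illustrates why zero padding works for variable length sequences. It is a plain NumPy forward algorithm written for this explanation, not the library's implementation: the function name hmm_log_prob is hypothetical, it handles a single unbatched sequence, and it assumes finite logits. An all-zero row selects no emission term, so that step contributes an emission factor of one.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log-sum-exp along one axis."""
    m = x.max(axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True)), axis=axis)


def hmm_log_prob(initial_logits, transition_logits, observation_logits, value):
    """Log probability of one one-hot sequence of shape (num_steps, categorical_size).

    Uses the standard convention: the first symbol is emitted from the initial state.
    """
    # (num_steps, state_dim): log emission probability of each step under each
    # state. An all-zero (missing) row gives zeros, i.e. log(1).
    value_logits = value @ observation_logits.T
    alpha = initial_logits + value_logits[0]
    for t in range(1, value.shape[0]):
        # Sum over the old state (axis 0), then weight by the emission at step t.
        alpha = logsumexp(alpha[:, None] + transition_logits, axis=0) + value_logits[t]
    return logsumexp(alpha, axis=0)


initial_logits = np.log(np.array([0.9, 0.1]))
transition_logits = np.log(np.array([[0.2, 0.8], [0.6, 0.4]]))
observation_logits = np.log(np.array([[0.7, 0.3], [0.1, 0.9]]))

seq = np.array([[1.0, 0.0], [0.0, 1.0]])        # symbols 0 then 1
padded = np.vstack([seq, np.zeros((3, 2))])     # same sequence, zero-padded

lp = hmm_log_prob(initial_logits, transition_logits, observation_logits, seq)
lp_padded = hmm_log_prob(initial_logits, transition_logits, observation_logits, padded)
```

Because each row of the transition matrix sums to one, the padded steps marginalize out and lp_padded equals lp, so sequences of different lengths can share one padded batch.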

class BiosequenceDataset(source, source_type='list', alphabet='amino-acid', max_length=None, include_stop=False, device=None)

Bases: torch.utils.data.dataset.Dataset