Source code for pyro.contrib.mue.dataloaders

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import torch
from torch.utils.data import Dataset

alphabets = {
    "amino-acid": np.array(
        [
            "R",
            "H",
            "K",
            "D",
            "E",
            "S",
            "T",
            "N",
            "Q",
            "C",
            "G",
            "P",
            "A",
            "V",
            "I",
            "L",
            "M",
            "F",
            "Y",
            "W",
        ]
    ),
    "dna": np.array(["A", "C", "G", "T"]),
}
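# Note (added for clarity, not in the original module): the symbol order in
# each alphabet array fixes the column order of the one-hot encodings built
# below; e.g. with the "dna" alphabet, "C" encodes to [0., 1., 0., 0.].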


class BiosequenceDataset(Dataset):
    """
    Load biological sequence data, either from a fasta file or a python list.

    :param source: Either the input fasta file path (str) or the input list
        of sequences (list of str).
    :param str source_type: Type of input, either 'list' or 'fasta'.
    :param str alphabet: Alphabet to use. Alphabets 'amino-acid' and 'dna'
        are preset; any other input will be interpreted as the alphabet
        itself, i.e. you can use 'ACGU' for RNA.
    :param int max_length: Total length of the one-hot representation of the
        sequences, including zero padding. Defaults to the maximum sequence
        length in the dataset.
    :param bool include_stop: Append stop symbol to the end of each sequence
        and add the stop symbol to the alphabet.
    :param torch.device device: Device on which data should be stored in
        memory.
    """

    def __init__(
        self,
        source,
        source_type="list",
        alphabet="amino-acid",
        max_length=None,
        include_stop=False,
        device=None,
    ):
        super().__init__()

        # Determine device
        if device is None:
            device = torch.tensor(0.0).device
        self.device = device

        # Get sequences.
        self.include_stop = include_stop
        if source_type == "list":
            seqs = [seq + include_stop * "*" for seq in source]
        elif source_type == "fasta":
            seqs = self._load_fasta(source)

        # Get lengths.
        self.L_data = torch.tensor([float(len(seq)) for seq in seqs], device=device)
        if max_length is None:
            self.max_length = int(torch.max(self.L_data))
        else:
            self.max_length = max_length
        self.data_size = len(self.L_data)

        # Get alphabet.
        if alphabet in alphabets:
            alphabet = alphabets[alphabet]
        else:
            alphabet = np.array(list(alphabet))
        if self.include_stop:
            alphabet = np.array(list(alphabet) + ["*"])
        self.alphabet = alphabet
        self.alphabet_length = len(alphabet)

        # Build dataset.
        self.seq_data = torch.cat(
            [self._one_hot(seq, alphabet, self.max_length).unsqueeze(0) for seq in seqs]
        )

    def _load_fasta(self, source):
        """A basic multiline fasta parser."""
        seqs = []
        seq = ""
        with open(source, "r") as fr:
            for line in fr:
                if line[0] == ">":
                    if seq != "":
                        if self.include_stop:
                            seq += "*"
                        seqs.append(seq)
                        seq = ""
                else:
                    seq += line.strip("\n")
        if seq != "":
            if self.include_stop:
                seq += "*"
            seqs.append(seq)
        return seqs

    def _one_hot(self, seq, alphabet, length):
        """One hot encode and pad with zeros to max length."""
        # One hot encode.
        oh = torch.tensor(
            (np.array(list(seq))[:, None] == alphabet[None, :]).astype(np.float64),
            device=self.device,
        )
        # Pad.
        x = torch.cat(
            [oh, torch.zeros([length - len(seq), len(alphabet)], device=self.device)]
        )
        return x

    def __len__(self):
        return self.data_size

    def __getitem__(self, ind):
        return (self.seq_data[ind], self.L_data[ind])
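# Usage sketch (illustrative, not part of the original module): build a
# dataset from an in-memory list of RNA sequences using a custom alphabet,
# then batch it with a standard DataLoader. The sequences and batch size here
# are assumptions.
from torch.utils.data import DataLoader

rna_data = BiosequenceDataset(
    ["ACGU", "GGUC", "AUU"],
    source_type="list",
    alphabet="ACGU",
    include_stop=True,
)
# Each item is a (one-hot sequence, length) pair. With include_stop=True the
# alphabet becomes A, C, G, U, *, so each one-hot tensor has shape
# (max_length, 5) and is zero padded past the stop symbol.
loader = DataLoader(rna_data, batch_size=2, shuffle=True)
for seq_oh, seq_len in loader:
    print(seq_oh.shape, seq_len)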
def write(x, alphabet, file, truncate_stop=False, append=False, scores=None):
    """
    Write sequence samples to file.

    :param ~torch.Tensor x: One-hot encoded sequences, with size
        ``(data_size, seq_length, alphabet_length)``. May be padded with
        zeros for variable length sequences.
    :param ~np.array alphabet: Alphabet.
    :param str file: Output file, where sequences will be written in
        fasta format.
    :param bool truncate_stop: If True, sequences will be truncated at the
        first stop symbol (i.e. the stop symbol and everything after will
        not be written). If False, the whole sequence will be written,
        including any internal stop symbols.
    :param bool append: If True, sequences are appended to the end of the
        output file. If False, the file is first erased.
    :param scores: Optional per-sequence values; when provided, each is used
        as the fasta header of its sequence in place of the sequence index.
    """
    # Add an extra, empty symbol to the alphabet; positions redirected to it
    # (stop symbols, padding) are written as empty strings.
    print_alphabet = np.array(list(alphabet) + [""])
    x = torch.cat([x, torch.zeros(list(x.shape[:2]) + [1])], -1)
    if truncate_stop:
        # Mask everything at and after the first stop symbol and redirect
        # those positions to the empty symbol.
        mask = (
            torch.cumsum(
                torch.matmul(
                    x, torch.tensor(print_alphabet == "*", dtype=torch.double)
                ),
                -1,
            )
            > 0
        ).to(torch.double)
        x = x * (1 - mask).unsqueeze(-1)
        x[:, :, -1] = mask
    else:
        # Redirect zero-padded positions to the empty symbol.
        x[:, :, -1] = (torch.sum(x, -1) < 0.5).to(torch.double)
    index = (
        torch.matmul(x, torch.arange(x.shape[-1], dtype=torch.double))
        .to(torch.long)
        .cpu()
        .numpy()
    )
    if scores is None:
        seqs = [
            ">{}\n".format(j) + "".join(elem) + "\n"
            for j, elem in enumerate(print_alphabet[index])
        ]
    else:
        seqs = [
            ">{}\n".format(j) + "".join(elem) + "\n"
            for j, elem in zip(scores, print_alphabet[index])
        ]
    if append:
        open_flag = "a"
    else:
        open_flag = "w"
    with open(file, open_flag) as fw:
        fw.write("".join(seqs))
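# Usage sketch (illustrative, not part of the original module): write the
# one-hot encoded sequences from the dataset built above back out as fasta.
# "samples.fasta" is a hypothetical output path.
write(
    rna_data.seq_data,
    rna_data.alphabet,
    "samples.fasta",
    truncate_stop=True,  # drop the stop symbol and trailing padding
)
# samples.fasta now contains records like ">0\nACGU\n", headed by the
# sequence index (or by `scores`, if provided).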