# Source code for pyro.contrib.examples.nextstrain

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import gzip
import logging
import os
import shutil
import subprocess
import urllib
import urllib.request

import torch

from .util import _mkdir_p, get_data_directory

# Local directory in which the dataset files are cached; resolved by
# ``get_data_directory`` for this module (exact location is decided by that
# helper, which is defined in ``.util``).
DATA = get_data_directory(__file__)
# Remote location of the gzipped, torch-serialized dataset. The last path
# component is reused as the local file name by ``load_nextstrain_counts``.
URL = "https://github.com/pyro-ppl/datasets/raw/master/nextstrain.data.pt.gz"


def load_nextstrain_counts(map_location=None) -> dict:
    """
    Loads a SARS-CoV-2 dataset.

    The original dataset is a preprocessed intermediate ``metadata.tsv.gz``
    available via `nextstrain
    <https://docs.nextstrain.org/projects/ncov/en/latest/reference/remote_inputs.html>`_.
    The ``metadata.tsv.gz`` file was then aggregated to
    (month,location,lineage) and (lineage,mutation) bins by the Broad
    Institute's `preprocessing script
    <https://github.com/broadinstitute/pyro-cov/blob/master/scripts/preprocess_nextstrain.py>`_.

    The gzipped file is downloaded from ``URL`` into the ``DATA`` directory
    on first use and decompressed next to it; later calls reuse both files.

    :param map_location: Forwarded to :func:`torch.load`. If ``None``, the
        device of a freshly created tensor (i.e. the current default device)
        is used.
    :returns: The object deserialized from the dataset file.
    :rtype: dict
    :raises urllib.error.URLError: If the download fails.
    :raises OSError: If the archive cannot be read or decompressed.
    """
    # Download the gzipped file.
    _mkdir_p(DATA)
    basename = URL.split("/")[-1]
    gz_filename = os.path.join(DATA, basename)
    if not os.path.exists(gz_filename):
        logging.debug(f"downloading {URL}")
        # Download to a temporary name and move into place afterwards, so an
        # interrupted transfer never leaves a truncated file that would be
        # mistaken for a valid cache entry on the next call.
        partial_gz = gz_filename + ".part"
        urllib.request.urlretrieve(URL, partial_gz)
        os.replace(partial_gz, gz_filename)

    # Decompress the file.
    # Strip only the trailing extension; a plain ``str.replace`` would also
    # mangle any ".gz" that happens to appear earlier in the path.
    filename = os.path.splitext(gz_filename)[0]
    if not os.path.exists(filename):
        logging.debug(f"unzipping {gz_filename}")
        # Use the standard library instead of an external ``gunzip`` binary
        # so this also works where that tool (or its ``-k`` flag) is missing.
        # The compressed file is kept, matching the previous ``-k`` behavior.
        partial = filename + ".part"
        with gzip.open(gz_filename, "rb") as src, open(partial, "wb") as dst:
            shutil.copyfileobj(src, dst)
        os.replace(partial, filename)

    # Load tensors to the default location.
    if map_location is None:
        map_location = torch.tensor(0.0).device
    # NOTE(security): ``weights_only=False`` performs full unpickling of a
    # file fetched over the network; this is only safe as long as ``URL``
    # points at a trusted source.
    return torch.load(filename, map_location=map_location, weights_only=False)