Source code for pyro.contrib.examples.util

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import os
import sys

import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms


class MNIST(datasets.MNIST):
    """MNIST dataset whose raw files are fetched from the Pyro datasets mirror.

    Overrides the ``mirrors`` list and the ``download`` method of
    ``torchvision.datasets.MNIST`` so that every resource is requested from
    the GitHub-hosted copy listed in ``mirrors``.
    """

    mirrors = ["https://github.com/pyro-ppl/datasets/blob/master/mnist/"]

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist already.

        Each file in ``self.resources`` is tried against every entry of
        ``self.mirrors`` in order; the first mirror that succeeds wins.

        Raises:
            RuntimeError: if a file could not be downloaded from any mirror.
                The message lists each mirror tried and the error it produced.
        """
        # ``URLError`` is defined in ``urllib.error``. Looking it up on the
        # ``torchvision.datasets`` package relies on an undocumented attribute
        # and, because the ``except`` expression is only evaluated while an
        # exception is propagating, a missing attribute would replace the real
        # download error with an ``AttributeError``.
        from urllib.error import URLError

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # download files
        for filename, md5 in self.resources:
            errors = []
            for mirror in self.mirrors:
                # NOTE(review): ``?raw=true`` presumably makes the GitHub blob
                # URL serve the file contents rather than an HTML page.
                url = f"{mirror}{filename}?raw=true"
                try:
                    datasets.utils.download_and_extract_archive(
                        url, download_root=self.raw_folder, filename=filename, md5=md5
                    )
                except URLError as e:
                    errors.append(e)
                    continue
                # Success: stop trying further mirrors for this file.
                break
            else:
                # The inner loop finished without ``break``: every mirror failed.
                details = "".join(
                    f"Tried {mirror}, got:\n{str(err)}\n"
                    for mirror, err in zip(self.mirrors, errors)
                )
                raise RuntimeError(f"Error downloading {filename}:\n{details}")
def get_data_loader(
    dataset_name,
    data_dir,
    batch_size=1,
    dataset_transforms=None,
    is_training_set=True,
    shuffle=True,
):
    """Build a ``DataLoader`` over a torchvision dataset, downloading it if needed.

    Args:
        dataset_name: name of the dataset class. ``"MNIST"`` selects the
            ``MNIST`` class defined in this module; any other name is looked
            up as an attribute of ``torchvision.datasets``.
        data_dir: passed to the dataset constructor as ``root``.
        batch_size: batch size handed to the ``DataLoader``.
        dataset_transforms: optional list of transforms applied after
            ``transforms.ToTensor()``. A falsy value means no extra transforms.
        is_training_set: passed to the dataset constructor as ``train``.
        shuffle: passed to the ``DataLoader`` as ``shuffle``.

    Returns:
        A ``DataLoader`` wrapping the constructed dataset.
    """
    extra_transforms = dataset_transforms if dataset_transforms else []
    # ``ToTensor`` always runs first; caller-supplied transforms follow it.
    pipeline = transforms.Compose([transforms.ToTensor()] + extra_transforms)

    dataset_cls = MNIST if dataset_name == "MNIST" else getattr(datasets, dataset_name)

    print("downloading data")
    dset = dataset_cls(
        root=data_dir,
        train=is_training_set,
        transform=pipeline,
        download=True,
    )
    print("download complete.")

    return DataLoader(dset, batch_size=batch_size, shuffle=shuffle)
def get_data_directory(filepath=None):
    """Return the absolute path of the ``.data`` directory to use.

    Args:
        filepath: path of a file whose containing directory should hold the
            ``.data`` directory. When omitted (``None``), the current working
            directory is used instead; previously the default value raised a
            ``TypeError`` from ``os.path.dirname(None)``.

    Returns:
        ``~/.data`` (with ``~`` expanded) when the ``CI`` environment variable
        is set, regardless of ``filepath``; otherwise the absolute path of
        ``.data`` next to ``filepath`` (or inside the current working
        directory when ``filepath`` is ``None``).
    """
    if "CI" in os.environ:
        return os.path.expanduser("~/.data")

    base_dir = os.getcwd() if filepath is None else os.path.dirname(filepath)
    return os.path.abspath(os.path.join(base_dir, ".data"))
def _mkdir_p(dirname): if not os.path.exists(dirname): try: os.makedirs(dirname) except FileExistsError: pass