# Source code for pyro.contrib.examples.util

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import os
import sys

import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms


class MNIST(datasets.MNIST):
    """torchvision MNIST subclass whose ``urls`` / ``resources`` class
    attributes point at a CloudFront-hosted copy of the four MNIST archives.

    Both attributes are defined so that the override is picked up by older
    torchvision releases (which read ``urls``) as well as newer ones (which
    read ``resources``).

    NOTE(review): recent torchvision versions may build download URLs as
    ``mirror + filename`` from ``resources``; confirm that full URLs in
    ``resources`` still work with the torchvision version in use.
    """

    # Older torchvision: a flat list of archive URLs.
    urls = [
        "https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-images-idx3-ubyte.gz",
        "https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-labels-idx1-ubyte.gz",
        "https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/t10k-images-idx3-ubyte.gz",
        "https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/t10k-labels-idx1-ubyte.gz",
    ]

    # Newer torchvision: the same URLs, each paired positionally with its
    # 32-hex-digit checksum (presumably MD5), as a list of 2-tuples.
    resources = list(
        zip(
            urls,
            (
                "f68b3c2dcbeaaa9fbdd348bbdeb94873",
                "d53e105ee54ea40749a09fcbcd1e9432",
                "9fb629c4189551a2d022fa330f9573f3",
                "ec29112dd5afa0611ce80d1b7f02629c",
            ),
        )
    )
def get_data_loader(
    dataset_name,
    data_dir,
    batch_size=1,
    dataset_transforms=None,
    is_training_set=True,
    shuffle=True,
):
    """Build a :class:`~torch.utils.data.DataLoader` over a torchvision dataset.

    :param str dataset_name: name of a class in ``torchvision.datasets``.
        ``"MNIST"`` is special-cased to use the :class:`MNIST` subclass
        defined in this module instead of the torchvision class.
    :param str data_dir: value passed as the dataset's ``root`` argument.
        ``download=True`` is always passed as well.
    :param int batch_size: batch size passed to the DataLoader.
    :param dataset_transforms: optional iterable (list, tuple, ...) of extra
        transforms, applied in order after ``transforms.ToTensor()``.
    :param bool is_training_set: passed as the dataset's ``train`` argument.
    :param bool shuffle: passed to the DataLoader.
    :return: a DataLoader wrapping the constructed dataset.
    :raises AttributeError: if ``dataset_name`` is not ``"MNIST"`` and is not
        an attribute of ``torchvision.datasets``.
    """
    # ToTensor always runs first. list() accepts any iterable of transforms,
    # so a tuple no longer fails the list concatenation with a TypeError.
    extra_transforms = list(dataset_transforms) if dataset_transforms else []
    trans = transforms.Compose([transforms.ToTensor()] + extra_transforms)

    if dataset_name == "MNIST":
        dataset = MNIST
    else:
        dataset = getattr(datasets, dataset_name)

    print("downloading data")
    dset = dataset(root=data_dir, train=is_training_set, transform=trans, download=True)
    print("download complete.")
    return DataLoader(dset, batch_size=batch_size, shuffle=shuffle)
def get_data_directory(filepath=None):
    """Return the path of the directory in which example data should live.

    If the ``CI`` environment variable is set, ``~/.data`` (with ``~``
    expanded) is returned and ``filepath`` is ignored. Otherwise the result
    is the absolute path of a ``.data`` directory located next to
    ``filepath`` (e.g. a caller's ``__file__``). When ``filepath`` is
    ``None``, the ``.data`` directory is resolved against the current
    working directory.

    The directory itself is not created.

    :param filepath: path of a file whose parent directory should contain
        the ``.data`` folder, or ``None``.
    :return: str, the data directory path.
    """
    if "CI" in os.environ:
        return os.path.expanduser("~/.data")
    # os.path.dirname(None) raises TypeError, which made the None default
    # unusable outside CI; an empty parent resolves relative to the CWD.
    parent = "" if filepath is None else os.path.dirname(filepath)
    return os.path.abspath(os.path.join(parent, ".data"))
def _mkdir_p(dirname):
    """Create ``dirname`` and any missing parent directories, like ``mkdir -p``.

    Returns quietly if ``dirname`` already exists as a directory, including
    when it appears concurrently. If the path exists but is not a directory,
    ``FileExistsError`` is raised. The previous version returned silently in
    that case, which only deferred the failure to the first write into it.
    """
    os.makedirs(dirname, exist_ok=True)