Pyro Examples

Datasets

Multi MNIST

This script generates a dataset similar to the Multi-MNIST dataset described in [1].

[1] Eslami, SM Ali, et al. “Attend, infer, repeat: Fast scene understanding with generative models.” Advances in Neural Information Processing Systems. 2016.

imresize(arr, size)
sample_one(canvas_size, mnist)
sample_multi(num_digits, canvas_size, mnist)
mk_dataset(n, mnist, max_digits, canvas_size)
load_mnist(root_path)
load(root_path)
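
The idea behind a Multi-MNIST-style generator can be shown without MNIST itself. The sketch below is a minimal, stdlib-only approximation under stated assumptions and is not the Pyro implementation. It treats each digit as a small square grid of intensities in [0, 1], pastes a random number of digits (0 up to a hypothetical `max_digits`) at uniformly random offsets on a larger blank canvas, and clips overlapping pixels at 1.0. The helper names `paste_digit` and `make_canvas` and the toy 3x3 "digits" are hypothetical stand-ins for 28x28 MNIST images and their labels.

```python
import random

def paste_digit(canvas, digit, top, left):
    """Add `digit` onto `canvas` in place at (top, left), clipping each pixel at 1.0."""
    for r, row in enumerate(digit):
        for c, value in enumerate(row):
            canvas[top + r][left + c] = min(1.0, canvas[top + r][left + c] + value)

def make_canvas(digits, labels, canvas_size, max_digits, rng):
    """Return (canvas, labels) with 0..max_digits digits placed at random offsets."""
    canvas = [[0.0] * canvas_size for _ in range(canvas_size)]
    placed = []
    for _ in range(rng.randint(0, max_digits)):  # randint is inclusive on both ends
        i = rng.randrange(len(digits))
        digit = digits[i]
        size = len(digit)  # digits are assumed square
        top = rng.randint(0, canvas_size - size)
        left = rng.randint(0, canvas_size - size)
        paste_digit(canvas, digit, top, left)
        placed.append(labels[i])
    return canvas, placed

# Hypothetical toy 3x3 "digits" standing in for real MNIST images.
toy_digits = [
    [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]],  # a "1"
    [[1.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],  # a "0"
]
toy_labels = [1, 0]

rng = random.Random(0)  # fixed seed for reproducibility
dataset = [make_canvas(toy_digits, toy_labels, 8, 2, rng) for _ in range(5)]
canvas, labels = dataset[0]
print(len(canvas), len(canvas[0]), labels)
```

Clipping is the simplest way to handle overlapping digits; rejecting and resampling a canvas whose pixels overflow is a reasonable alternative if overlaps should stay visually distinct.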

BART Ridership

load_bart_od()

Load a dataset of hourly origin-destination ridership counts for every pair of BART stations during the years 2011-2018.

Source: https://www.bart.gov/about/reports/ridership

The first call downloads and preprocesses the dataset, which requires about 300 MB of downloads and a few GB of temporary files. Subsequent calls read from a cached .pkl.bz2 file instead.
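
The download-once, read-from-cache behavior described above is a common pattern. The sketch below shows one way to implement it with the standard library's `bz2` and `pickle` modules; it is an illustration under assumptions, not Pyro's code. `load_cached` and the `build` callback (which would perform the expensive download and preprocessing) are hypothetical names.

```python
import bz2
import os
import pickle
import tempfile

def load_cached(cache_path, build):
    """Return the object cached at `cache_path`, calling `build()` only on a cache miss."""
    if os.path.exists(cache_path):
        with bz2.open(cache_path, "rb") as f:
            return pickle.load(f)
    dataset = build()  # expensive step, e.g. download + preprocess (hypothetical)
    tmp_path = cache_path + ".tmp"
    with bz2.open(tmp_path, "wb") as f:
        pickle.dump(dataset, f)
    # Rename only after the write completes, so an interrupted run never
    # leaves a truncated file at `cache_path`.
    os.replace(tmp_path, cache_path)
    return dataset

# Usage: the second call hits the cache, so `build` runs only once.
calls = []

def build():
    calls.append(1)
    return {"stations": ["A", "B"]}

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "toy.pkl.bz2")
    first = load_cached(path, build)
    second = load_cached(path, build)
```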

Returns: a dataset dictionary with fields:
  • "stations": a list of station names (strings)
  • "start_date": a datetime.datetime for the first observation
  • "counts": a torch.FloatTensor of ridership counts, with shape (num_hours, len(stations), len(stations)).
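
Because observations are hourly and begin at "start_date", hour index t corresponds to start_date + t hours. The sketch below uses a hand-made dictionary with the same three fields, with nested Python lists in place of the torch.FloatTensor so it runs without PyTorch. It assumes that counts[t][i][j] is the number of riders entering at station i and exiting at station j during hour t; that index order is an assumption, not something stated above. The helper `riders_at`, the station names, and the date are hypothetical.

```python
import datetime

def riders_at(dataset, when, origin, destination):
    """Count for one (hour, origin, destination) cell, assuming counts[t][origin][dest]."""
    # timedelta // timedelta is an integer floor division: whole hours since start.
    t = (when - dataset["start_date"]) // datetime.timedelta(hours=1)
    i = dataset["stations"].index(origin)
    j = dataset["stations"].index(destination)
    return dataset["counts"][t][i][j]

# Toy dataset: 2 hours x 2 stations x 2 stations (illustrative values only).
toy = {
    "stations": ["Station A", "Station B"],
    "start_date": datetime.datetime(2011, 1, 1, 0, 0),
    "counts": [
        [[0.0, 5.0], [3.0, 0.0]],  # hour 0
        [[1.0, 7.0], [2.0, 0.0]],  # hour 1
    ],
}

n = riders_at(toy, datetime.datetime(2011, 1, 1, 1, 30), "Station A", "Station B")
```

With the real tensor, whole-array reductions are simpler: under the same index-order assumption, counts.sum(-1) gives per-origin totals for every hour.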
load_fake_od()

Create a tiny synthetic dataset for smoke testing.
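
A smoke-test dataset only needs the right fields, shapes, and types, not realistic values. The following sketch builds such a stand-in, assuming it should mirror the "stations" / "start_date" / "counts" schema returned by load_bart_od(). `make_fake_od` and its parameters are hypothetical, and this is not the body of load_fake_od(). Nested lists again replace the tensor so the example needs only the standard library.

```python
import datetime
import random

def make_fake_od(num_hours=3, num_stations=4, max_count=10, seed=0):
    """Return a tiny dict with the assumed real-dataset fields, filled with random counts."""
    rng = random.Random(seed)  # seeded so smoke tests are reproducible
    stations = ["station_{}".format(k) for k in range(num_stations)]
    counts = [
        [[float(rng.randint(0, max_count)) for _ in stations] for _ in stations]
        for _ in range(num_hours)
    ]
    return {
        "stations": stations,
        "start_date": datetime.datetime(2000, 1, 1),  # arbitrary fixed start
        "counts": counts,  # shape (num_hours, num_stations, num_stations)
    }

fake = make_fake_od()
```

In a PyTorch setting one would wrap counts with torch.tensor(...) to match the real field's type.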

Utilities

get_data_loader(dataset_name, data_dir, batch_size=1, dataset_transforms=None, is_training_set=True, shuffle=True)
print_and_log(logger, msg)
get_data_directory(filepath=None)
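
Only the signature of print_and_log(logger, msg) is given here. Judging by the name alone, it presumably echoes msg to the console and also records it through logger. The sketch below is a guess at that behavior, assuming logger is either None or a file-like object with write and flush methods; `print_and_log_sketch` is a hypothetical re-implementation, not Pyro's code.

```python
import io
import sys

def print_and_log_sketch(logger, msg):
    """Print `msg`; if `logger` is given (assumed file-like), also append it as a line."""
    print(msg)
    sys.stdout.flush()
    if logger is not None:
        logger.write("{}\n".format(msg))
        logger.flush()

log = io.StringIO()  # in-memory stand-in for a log file
print_and_log_sketch(log, "epoch 1 loss 0.5")
print_and_log_sketch(None, "console only")
```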