Source code for pyro.contrib.examples.bart

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import argparse
import csv
import datetime
import logging
import multiprocessing
import os
import subprocess
import sys
import urllib.error
import urllib.request

import torch

from pyro.contrib.examples.util import get_data_directory

DATA = get_data_directory(__file__)

# https://www.bart.gov/about/reports/ridership
SOURCE_DIR = "http://64.111.127.166/origin-destination/"
SOURCE_FILES = [
    "date-hour-soo-dest-2011.csv.gz",
    "date-hour-soo-dest-2012.csv.gz",
    "date-hour-soo-dest-2013.csv.gz",
    "date-hour-soo-dest-2014.csv.gz",
    "date-hour-soo-dest-2015.csv.gz",
    "date-hour-soo-dest-2016.csv.gz",
    "date-hour-soo-dest-2017.csv.gz",
    "date-hour-soo-dest-2018.csv.gz",
]
CACHE_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/bart_full.pkl.bz2"


def _mkdir_p(dirname):
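    # Equivalent in effect to os.makedirs(dirname, exist_ok=True); the try/except
    # guards against a race when multiple worker processes create DATA concurrently.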
    if not os.path.exists(dirname):
        try:
            os.makedirs(dirname)
        except FileExistsError:
            pass


def _load_hourly_od(basename):
    """Download and convert one year of hourly origin-destination counts,
    returning the path to a cached ``.pkl`` file."""
    filename = os.path.join(DATA, basename.replace(".csv.gz", ".pkl"))
    if os.path.exists(filename):
        return filename

    # Download source files.
    _mkdir_p(DATA)
    gz_filename = os.path.join(DATA, basename)
    if not os.path.exists(gz_filename):
        url = SOURCE_DIR + basename
        logging.debug("downloading {}".format(url))
        urllib.request.urlretrieve(url, gz_filename)
    csv_filename = gz_filename[:-3]
    assert csv_filename.endswith(".csv")
    if not os.path.exists(csv_filename):
        logging.debug("unzipping {}".format(gz_filename))
        subprocess.check_call(["gunzip", "-k", gz_filename])
    assert os.path.exists(csv_filename)

    # Convert to PyTorch.
    logging.debug("converting {}".format(csv_filename))
    start_date = datetime.datetime.strptime("2000-01-01", "%Y-%m-%d")
    stations = {}
    num_rows = sum(1 for _ in open(csv_filename))
    logging.info("Formatting {} rows".format(num_rows))
    rows = torch.empty((num_rows, 4), dtype=torch.long)
    with open(csv_filename) as f:
        for i, (date, hour, origin, destin, trip_count) in enumerate(csv.reader(f)):
            date = datetime.datetime.strptime(date, "%Y-%m-%d")
            date += datetime.timedelta(hours=int(hour))
            rows[i, 0] = int((date - start_date).total_seconds() / 3600)
            rows[i, 1] = stations.setdefault(origin, len(stations))
            rows[i, 2] = stations.setdefault(destin, len(stations))
            rows[i, 3] = int(trip_count)
            if i % 10000 == 0:
                sys.stderr.write(".")
                sys.stderr.flush()

    # Save data with metadata.
    dataset = {
        "basename": basename,
        "start_date": start_date,
        "stations": stations,
        "rows": rows,
        "schema": ["time_hours", "origin", "destin", "trip_count"],
    }
    dataset["rows"]
    logging.debug("saving {}".format(filename))
    torch.save(dataset, filename)
    return filename
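
# A minimal inspection sketch (not part of the original module), assuming the
# 2011 shard has already been preprocessed by _load_hourly_od():
#
#     shard = torch.load(os.path.join(DATA, "date-hour-soo-dest-2011.pkl"))
#     shard["schema"]        # ["time_hours", "origin", "destin", "trip_count"]
#     shard["rows"].shape    # (num_rows, 4), dtype torch.long
#     shard["stations"]      # dict mapping station name -> contiguous integer id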


def load_bart_od():
    """
    Load a dataset of hourly origin-destination ridership counts for every pair
    of BART stations during the years 2011-2018.

    **Source** https://www.bart.gov/about/reports/ridership

    This downloads and preprocesses the dataset the first time it is called,
    requiring about 300MB of file transfer and storing a few GB of temp files.
    On subsequent calls this reads from a cached ``.pkl.bz2`` file.

    :returns: a dataset, i.e. a dictionary with fields:

        - "stations": a list of strings of station names
        - "start_date": a :py:class:`datetime.datetime` for the first observation
        - "counts": a ``torch.FloatTensor`` of ridership counts, with shape
          ``(num_hours, len(stations), len(stations))``.
    """
    filename = os.path.join(DATA, "bart_full.pkl.bz2")
    # Work around an apparent bug in torch.load(), torch.save().
    pkl_file = filename.rsplit(".", 1)[0]
    if not os.path.exists(pkl_file):
        try:
            urllib.request.urlretrieve(CACHE_URL, filename)
            logging.debug("cache hit, uncompressing")
            subprocess.check_call(["bunzip2", "-k", filename])
        except urllib.error.HTTPError:
            logging.debug("cache miss, preprocessing from scratch")
    if os.path.exists(pkl_file):
        return torch.load(pkl_file)

    filenames = multiprocessing.Pool().map(_load_hourly_od, SOURCE_FILES)
    datasets = list(map(torch.load, filenames))

    stations = sorted(set().union(*(d["stations"].keys() for d in datasets)))
    min_time = min(int(d["rows"][:, 0].min()) for d in datasets)
    max_time = max(int(d["rows"][:, 0].max()) for d in datasets)
    num_rows = max_time - min_time + 1
    start_date = datasets[0]["start_date"] + datetime.timedelta(hours=min_time)
    logging.info("Loaded data from {} stations, {} hours"
                 .format(len(stations), num_rows))

    result = torch.zeros(num_rows, len(stations), len(stations))
    for dataset in datasets:
        part_stations = sorted(dataset["stations"],
                               key=dataset["stations"].__getitem__)
        part_to_whole = torch.tensor(list(map(stations.index, part_stations)))
        time = dataset["rows"][:, 0] - min_time
        origin = part_to_whole[dataset["rows"][:, 1]]
        destin = part_to_whole[dataset["rows"][:, 2]]
        count = dataset["rows"][:, 3].float()
        result[time, origin, destin] = count
        dataset.clear()
    logging.info("Loaded {} shaped data of mean {:0.3g}"
                 .format(result.shape, result.mean()))

    dataset = {
        "stations": stations,
        "start_date": start_date,
        "counts": result,
    }
    torch.save(dataset, pkl_file)
    subprocess.check_call(["bzip2", "-k", pkl_file])
    assert os.path.exists(filename)
    return dataset
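
# Usage sketch (not part of the original module; assumes "EMBR" appears in
# dataset["stations"]; counts are indexed as [hour, origin, destin] per the
# assembly loop above):
#
#     dataset = load_bart_od()
#     counts = dataset["counts"]                  # (num_hours, S, S)
#     embr = dataset["stations"].index("EMBR")
#     arrivals = counts[:, :, embr].sum(-1)       # hourly trips ending at EMBR
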
def load_fake_od():
    """
    Create a tiny synthetic dataset for smoke testing.
    """
    dataset = {
        "stations": ["12TH", "EMBR", "SFIA"],
        "start_date": datetime.datetime.strptime("2000-01-01", "%Y-%m-%d"),
        "counts": torch.distributions.Poisson(100).sample([24 * 7 * 8, 3, 3]),
    }
    return dataset

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="BART data preprocessor")
    parser.add_argument("-v", "--verbose", action="store_true")
    args = parser.parse_args()
    logging.basicConfig(format='%(relativeCreated) 9d %(message)s',
                        level=logging.DEBUG if args.verbose else logging.INFO)
    load_bart_od()
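
# To pre-populate the data cache from the command line, this module can be run
# directly (a sketch; -v enables DEBUG-level logging as parsed above):
#
#     python -m pyro.contrib.examples.bart -v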