# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import argparse
import csv
import datetime
import logging
import multiprocessing
import os
import subprocess
import sys
import urllib.error
import urllib.request

import torch

from pyro.contrib.examples.util import get_data_directory

DATA = get_data_directory(__file__)

# https://www.bart.gov/about/reports/ridership
SOURCE_DIR = "http://64.111.127.166/origin-destination/"
SOURCE_FILES = [
    "date-hour-soo-dest-2011.csv.gz",
    "date-hour-soo-dest-2012.csv.gz",
    "date-hour-soo-dest-2013.csv.gz",
    "date-hour-soo-dest-2014.csv.gz",
    "date-hour-soo-dest-2015.csv.gz",
    "date-hour-soo-dest-2016.csv.gz",
    "date-hour-soo-dest-2017.csv.gz",
    "date-hour-soo-dest-2018.csv.gz",
]
CACHE_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/bart_full.pkl.bz2"


def _mkdir_p(dirname):
    if not os.path.exists(dirname):
        try:
            os.makedirs(dirname)
        except FileExistsError:
            # Tolerate races between concurrent workers creating the same dir.
            pass


def _load_hourly_od(basename):
    filename = os.path.join(DATA, basename.replace(".csv.gz", ".pkl"))
    if os.path.exists(filename):
        return filename

    # Download and unzip the source file.
    _mkdir_p(DATA)
    gz_filename = os.path.join(DATA, basename)
    if not os.path.exists(gz_filename):
        url = SOURCE_DIR + basename
        logging.debug("downloading {}".format(url))
        urllib.request.urlretrieve(url, gz_filename)
    csv_filename = gz_filename[:-3]
    assert csv_filename.endswith(".csv")
    if not os.path.exists(csv_filename):
        logging.debug("unzipping {}".format(gz_filename))
        subprocess.check_call(["gunzip", "-k", gz_filename])
    assert os.path.exists(csv_filename)

    # Convert to PyTorch, encoding each station name as an integer id.
    logging.debug("converting {}".format(csv_filename))
    start_date = datetime.datetime.strptime("2000-01-01", "%Y-%m-%d")
    stations = {}
    with open(csv_filename) as f:
        num_rows = sum(1 for _ in f)
    logging.info("Formatting {} rows".format(num_rows))
    rows = torch.empty((num_rows, 4), dtype=torch.long)
    with open(csv_filename) as f:
        for i, (date, hour, origin, destin, trip_count) in enumerate(csv.reader(f)):
            date = datetime.datetime.strptime(date, "%Y-%m-%d")
            date += datetime.timedelta(hours=int(hour))
            rows[i, 0] = int((date - start_date).total_seconds() / 3600)
            rows[i, 1] = stations.setdefault(origin, len(stations))
            rows[i, 2] = stations.setdefault(destin, len(stations))
            rows[i, 3] = int(trip_count)
            if i % 10000 == 0:
                sys.stderr.write(".")
                sys.stderr.flush()

    # Save data with metadata.
    dataset = {
        "basename": basename,
        "start_date": start_date,
        "stations": stations,
        "rows": rows,
        "schema": ["time_hours", "origin", "destin", "trip_count"],
    }
    logging.debug("saving {}".format(filename))
    torch.save(dataset, filename)
    return filename
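

# Illustrative sketch, not part of the original module: the per-year pickles
# above store trips in a compact long-tensor schema. This hypothetical helper
# shows how a row produced by _load_hourly_od() can be decoded back into a
# readable record.
def _example_decode_row(dataset, i=0):
    """Decode row ``i`` of a per-year dataset into a readable dict."""
    # Invert the station-name -> integer-id mapping built during conversion.
    index_to_station = {v: k for k, v in dataset["stations"].items()}
    time_hours, origin, destin, trip_count = dataset["rows"][i].tolist()
    return {
        "time": dataset["start_date"] + datetime.timedelta(hours=time_hours),
        "origin": index_to_station[origin],
        "destin": index_to_station[destin],
        "trip_count": trip_count,
    }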


def load_bart_od():
    """
    Load a dataset of hourly origin-destination ridership counts for every
    pair of BART stations during the years 2011-2018.

    **Source** https://www.bart.gov/about/reports/ridership

    This downloads and preprocesses the dataset the first time it is called,
    requiring about 300MB of file transfer and storing a few GB of temp
    files. On subsequent calls this reads from a cached ``.pkl.bz2`` file.

    :returns: a dataset dictionary with fields:

        - "stations": a list of strings of station names
        - "start_date": a :py:class:`datetime.datetime` for the first
          observation
        - "counts": a ``torch.FloatTensor`` of ridership counts, with shape
          ``(num_hours, len(stations), len(stations))``.
    """
    filename = os.path.join(DATA, "bart_full.pkl.bz2")
    # Work around an apparent bug in torch.load()/torch.save().
    pkl_file = filename.rsplit(".", 1)[0]
    if not os.path.exists(pkl_file):
        try:
            urllib.request.urlretrieve(CACHE_URL, filename)
            logging.debug("cache hit, uncompressing")
            subprocess.check_call(["bunzip2", "-k", filename])
        except urllib.error.HTTPError:
            logging.debug("cache miss, preprocessing from scratch")
    if os.path.exists(pkl_file):
        return torch.load(pkl_file)

    # Preprocess each year in parallel, then merge into a single dense tensor.
    with multiprocessing.Pool() as pool:
        filenames = pool.map(_load_hourly_od, SOURCE_FILES)
    datasets = list(map(torch.load, filenames))

    stations = sorted(set().union(*(d["stations"].keys() for d in datasets)))
    min_time = min(int(d["rows"][:, 0].min()) for d in datasets)
    max_time = max(int(d["rows"][:, 0].max()) for d in datasets)
    num_rows = max_time - min_time + 1
    start_date = datasets[0]["start_date"] + datetime.timedelta(hours=min_time)
    logging.info("Loaded data from {} stations, {} hours"
                 .format(len(stations), num_rows))
    result = torch.zeros(num_rows, len(stations), len(stations))
    for dataset in datasets:
        part_stations = sorted(dataset["stations"],
                               key=dataset["stations"].__getitem__)
        part_to_whole = torch.tensor(list(map(stations.index, part_stations)))
        time = dataset["rows"][:, 0] - min_time
        origin = part_to_whole[dataset["rows"][:, 1]]
        destin = part_to_whole[dataset["rows"][:, 2]]
        count = dataset["rows"][:, 3].float()
        result[time, origin, destin] = count
        dataset.clear()
    logging.info("Loaded {} shaped data of mean {:0.3g}"
                 .format(result.shape, result.mean()))

    dataset = {
        "stations": stations,
        "start_date": start_date,
        "counts": result,
    }
    torch.save(dataset, pkl_file)
    subprocess.check_call(["bzip2", "-k", pkl_file])
    assert os.path.exists(filename)
    return dataset
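

# Illustrative usage sketch, not part of the original module: extract the
# hourly time series for a single origin-destination pair. The station
# abbreviations "EMBR" and "SFIA" are assumed to appear in the station list.
def _example_od_series(origin="EMBR", destin="SFIA"):
    dataset = load_bart_od()
    i = dataset["stations"].index(origin)
    j = dataset["stations"].index(destin)
    # counts has shape (num_hours, num_stations, num_stations).
    return dataset["counts"][:, i, j]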


def load_fake_od():
    """
    Create a tiny synthetic dataset for smoke testing.
    """
    dataset = {
        "stations": ["12TH", "EMBR", "SFIA"],
        "start_date": datetime.datetime.strptime("2000-01-01", "%Y-%m-%d"),
        "counts": torch.distributions.Poisson(100).sample([24 * 7 * 8, 3, 3]),
    }
    return dataset
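

# A minimal smoke-test sketch, not part of the original module: load_fake_od()
# returns the same schema as load_bart_od(), so it can stand in for the real
# dataset in fast tests.
def _example_smoke_test():
    dataset = load_fake_od()
    assert dataset["counts"].shape == (24 * 7 * 8, 3, 3)
    assert dataset["start_date"].year == 2000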


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="BART data preprocessor")
    parser.add_argument("-v", "--verbose", action="store_true")
    args = parser.parse_args()
    logging.basicConfig(format="%(relativeCreated) 9d %(message)s",
                        level=logging.DEBUG if args.verbose else logging.INFO)
    load_bart_od()