# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
"""
This script generates a dataset similar to the Multi-MNIST dataset
described in [1].
[1] Eslami, SM Ali, et al. "Attend, infer, repeat: Fast scene
understanding with generative models." Advances in Neural Information
Processing Systems. 2016.
"""
import os
import numpy as np
from PIL import Image
from pyro.contrib.examples.util import get_data_loader


def imresize(arr, size):
    # Resize a 2-D array with PIL and convert it back to a numpy array.
    return np.array(Image.fromarray(arr).resize(size))


def sample_one(canvas_size, mnist):
    # Pick a random MNIST digit and rescale it by a factor drawn around 1.3.
    i = np.random.randint(mnist["digits"].shape[0])
    digit = mnist["digits"][i]
    label = mnist["labels"][i].item()
    scale = 0.1 * np.random.randn() + 1.3
    new_size = tuple(int(s / scale) for s in digit.shape)
    resized = imresize(digit, new_size)
    w = resized.shape[0]
    assert w == resized.shape[1]
    # Pad the resized digit to place it at a random position on the canvas.
    padding = canvas_size - w
    pad_l = np.random.randint(0, padding)
    pad_r = np.random.randint(0, padding)
    pad_width = ((pad_l, padding - pad_l), (pad_r, padding - pad_r))
    positioned = np.pad(resized, pad_width, "constant", constant_values=0)
    return positioned, label


def sample_multi(num_digits, canvas_size, mnist):
    canvas = np.zeros((canvas_size, canvas_size))
    labels = []
    for _ in range(num_digits):
        positioned_digit, label = sample_one(canvas_size, mnist)
        canvas += positioned_digit
        labels.append(label)
    # Crude check for overlapping digits: summed pixel values above 255 can
    # only arise from overlap, so resample the whole canvas in that case.
    if np.max(canvas) > 255:
        return sample_multi(num_digits, canvas_size, mnist)
    else:
        return canvas, labels


def mk_dataset(n, mnist, max_digits, canvas_size):
    # Generate n canvases, each containing between 0 and max_digits digits.
    x = []
    y = []
    for _ in range(n):
        num_digits = np.random.randint(max_digits + 1)
        canvas, labels = sample_multi(num_digits, canvas_size, mnist)
        x.append(canvas)
        y.append(labels)
    return np.array(x, dtype=np.uint8), np.array(y, dtype=object)


def load_mnist(root_path):
    loader = get_data_loader("MNIST", root_path)
    return {
        "digits": loader.dataset.data.cpu().numpy(),
        "labels": loader.dataset.targets,
    }


def load(root_path):
    file_path = os.path.join(root_path, "multi_mnist_uint8.npz")
    if os.path.exists(file_path):
        data = np.load(file_path, allow_pickle=True)
        return data["x"], data["y"]
    else:
        # Set RNG to a known state so the generated dataset is reproducible.
        rng_state = np.random.get_state()
        np.random.seed(681307)
        mnist = load_mnist(root_path)
        print("Generating multi-MNIST dataset...")
        x, y = mk_dataset(60000, mnist, 2, 50)
        # Revert RNG state.
        np.random.set_state(rng_state)
        # Crude checksum.
        # assert x.sum() == 883114919, 'Did not generate the expected data.'
        with open(file_path, "wb") as f:
            np.savez_compressed(f, x=x, y=y)
        print("Done!")
        return x, y
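

# Minimal usage sketch (an illustrative addition, not part of the original
# script): ``load`` expects a writable data directory; the ".data" path below
# is an arbitrary example, not something the module mandates. On first use it
# downloads MNIST, composes the multi-MNIST canvases, and caches them as
# multi_mnist_uint8.npz; later calls just load the cached file.
if __name__ == "__main__":
    x, y = load(".data")
    print("images:", x.shape, x.dtype)       # expected: (60000, 50, 50) uint8
    print("labels for first canvas:", y[0])  # a list of up to 2 digit labels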