import torch
from pyro.poutine.util import is_validation_enabled
from .messenger import Messenger
[docs]class ScaleMessenger(Messenger):
"""
This messenger rescales the log probability score.
This is typically used for data subsampling or for stratified sampling of data
(e.g. in fraud detection where negatives vastly outnumber positives).
:param scale: a positive scaling factor
:type scale: float or torch.Tensor
"""
def __init__(self, scale):
if isinstance(scale, torch.Tensor):
if is_validation_enabled() and not (scale > 0).all():
raise ValueError("Expected scale > 0 but got {}. ".format(scale) +
"Consider using poutine.mask() instead of poutine.scale().")
elif not (scale > 0):
raise ValueError("Expected scale > 0 but got {}".format(scale))
super(ScaleMessenger, self).__init__()
self.scale = scale
def _process_message(self, msg):
msg["scale"] = self.scale * msg["scale"]
return None