Open mtvector opened 12 months ago
FYI: I also took a stab at fixing the straight-through categorical. It could still use some work, but it works for me where the previous RelaxedCategoricalStraightThrough would not train as part of a GMM-VAE:
class RelaxedQuantizeCategorical(torch.autograd.Function):
    """Gumbel-softmax sample in log space with a straight-through backward pass.

    ``forward`` perturbs unnormalized logits with Gumbel noise and divides by
    ``temperature``. It then takes the softmax over the last axis and adds
    ``epsilon`` to every probability so no entry underflows to exactly 0.
    Finally it renormalizes over the last axis and returns the log.
    ``backward`` passes the incoming gradient unchanged to the logits.

    ``temperature`` and ``epsilon`` may be passed explicitly to ``apply``.
    When omitted, the class-level values set through :meth:`set_temperature`
    and :meth:`set_epsilon` are used, which is the original calling
    convention. Passing them explicitly avoids one caller's settings leaking
    into another's.
    """

    temperature = None  # class-level default; must be set if not passed to apply
    epsilon = 1e-10  # class-level default floor added to each relaxed probability

    @staticmethod
    def set_temperature(new_temperature):
        """Set the class-level temperature used when ``apply`` gets none."""
        RelaxedQuantizeCategorical.temperature = new_temperature

    @staticmethod
    def set_epsilon(new_epsilon):
        """Set the class-level epsilon used when ``apply`` gets none."""
        RelaxedQuantizeCategorical.epsilon = new_epsilon

    @staticmethod
    def forward(ctx, soft_value, temperature=None, epsilon=None):
        """Draw a relaxed categorical sample in log space.

        Args:
            soft_value: unnormalized logits; categories lie along the last dim.
            temperature: relaxation temperature (scalar or tensor). Defaults to
                the class-level value.
            epsilon: additive underflow floor. Defaults to the class-level value.

        Returns:
            Tensor shaped like ``soft_value`` holding log-probabilities that
            sum (after ``exp``) to 1 along the last dim.

        Raises:
            TypeError: if no temperature was passed or set.
        """
        if temperature is None:
            temperature = RelaxedQuantizeCategorical.temperature
        if temperature is None:
            raise TypeError(
                "RelaxedQuantizeCategorical: temperature is not set; pass it to "
                "apply() or call set_temperature() first"
            )
        if torch.is_tensor(temperature) and temperature.numel() > 1:
            # Per-element temperatures: match soft_value so broadcasting works.
            temperature = temperature.to(device=soft_value.device, dtype=soft_value.dtype)
        else:
            # Scalar-like temperature: a Python float avoids device mismatches.
            temperature = float(temperature)
        if epsilon is None:
            epsilon = RelaxedQuantizeCategorical.epsilon

        uniforms = clamp_probs(
            torch.rand(soft_value.shape, dtype=soft_value.dtype, device=soft_value.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (soft_value + gumbels) / temperature
        probs = (scores - scores.logsumexp(dim=-1, keepdim=True)).exp()
        probs = probs + epsilon  # keep every entry > 0 so the log below is finite
        # Renormalize over the category axis (last dim), matching logsumexp
        # above; dim=1 would be a batch axis for inputs with more than 2 dims.
        hard_value = (probs / probs.sum(-1, keepdim=True)).log()
        hard_value._unquantize = soft_value
        return hard_value

    @staticmethod
    def backward(ctx, grad):
        # Straight-through: identity gradient for soft_value and no gradient
        # for temperature/epsilon. Extra trailing Nones beyond the inputs
        # actually passed to apply() are discarded by autograd.
        return grad, None, None
class ExpRelaxedCategoricalStraightThrough(Distribution):
    """Relaxed categorical over log-probabilities with a straight-through gradient.

    ``rsample`` draws a log-space Gumbel-softmax sample through
    ``RelaxedQuantizeCategorical``, which adds ``epsilon`` before the log so
    entries cannot underflow to ``-inf``. Gradients pass straight through to
    ``logits``. ``log_prob`` is the ExpConcrete log-density, the same formula
    used by ``torch.distributions.ExpRelaxedCategorical.log_prob``.

    Args:
        temperature: relaxation temperature.
        probs: event probabilities (give either ``probs`` or ``logits``).
        logits: unnormalized log-probabilities.
        validate_args: forwarded to ``Distribution``.
        epsilon: additive floor on each relaxed probability.
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    # The true support is a submanifold of real vectors (log of the simplex).
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None, epsilon=1e-10):
        self._categorical = Categorical(probs, logits)
        self.temperature = temperature
        self.epsilon = epsilon
        # Kept for callers that use RelaxedQuantizeCategorical.apply directly;
        # rsample re-applies this instance's values before sampling.
        RelaxedQuantizeCategorical.set_temperature(temperature)
        RelaxedQuantizeCategorical.set_epsilon(epsilon)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        # Check against this class: it defines its own __init__, so checking
        # against ExpRelaxedCategorical raises NotImplementedError.
        new = self._get_checked_instance(ExpRelaxedCategoricalStraightThrough, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        new.epsilon = self.epsilon
        new._categorical = self._categorical.expand(batch_shape)
        super(ExpRelaxedCategoricalStraightThrough, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        return self._categorical.param_shape

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def probs(self):
        return self._categorical.probs

    def rsample(self, sample_shape=torch.Size()):
        """Draw a log-space relaxed sample of shape ``sample_shape + batch + event``."""
        shape = self._extended_shape(sample_shape)
        logits = self.logits.expand(shape)
        # RelaxedQuantizeCategorical reads class-level state by default; set it
        # from this instance so a distribution constructed later cannot
        # silently change this one's temperature or epsilon.
        RelaxedQuantizeCategorical.set_temperature(self.temperature)
        RelaxedQuantizeCategorical.set_epsilon(self.epsilon)
        return RelaxedQuantizeCategorical.apply(logits)

    def log_prob(self, value):
        """ExpConcrete log-density of a log-space sample ``value``.

        ``value`` is scored directly. The ``_unquantize`` attribute set by
        RelaxedQuantizeCategorical holds the input logits, not a sample, so
        it is not used here.
        """
        K = self._categorical._num_events
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        temperature = torch.as_tensor(
            self.temperature, dtype=logits.dtype, device=logits.device
        )
        log_scale = torch.full_like(temperature, float(K)).lgamma() - temperature.log().mul(
            -(K - 1)
        )
        score = logits - value.mul(temperature)
        score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
        return score + log_scale
class SafeAndRelaxedOneHotCategoricalStraightThrough(TransformedDistribution, TorchDistributionMixin):
    """Relaxed one-hot categorical on the simplex with an underflow-safe,
    straight-through reparametrized sample.

    Built as ``ExpTransform`` applied to
    ``ExpRelaxedCategoricalStraightThrough``, so ``rsample`` exponentiates
    the log-space sample drawn by that base distribution.

    Args:
        temperature: relaxation temperature.
        probs: event probabilities (give either ``probs`` or ``logits``).
        logits: unnormalized log-probabilities.
        validate_args: forwarded to the base and transformed distributions.
        epsilon: underflow floor forwarded to the base distribution.
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None, epsilon=1e-10):
        base_dist = ExpRelaxedCategoricalStraightThrough(
            temperature, probs, logits, validate_args=validate_args, epsilon=epsilon
        )
        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        # Check against this class itself. Checking against
        # RelaxedOneHotCategorical (a class with a different __init__) makes
        # _get_checked_instance raise NotImplementedError when _instance is None.
        new = self._get_checked_instance(
            SafeAndRelaxedOneHotCategoricalStraightThrough, _instance
        )
        return super().expand(batch_shape, _instance=new)

    @property
    def temperature(self):
        """Relaxation temperature of the base distribution."""
        return self.base_dist.temperature

    @property
    def logits(self):
        """Normalized logits of the base distribution."""
        return self.base_dist.logits

    @property
    def probs(self):
        """Event probabilities of the base distribution."""
        return self.base_dist.probs
Hi @mtvector, I think our general design principle with distributions is to make them hackable with decent defaults. In this case I'd lean towards letting users add their own epsilon in a custom distribution class. In my own projects I often have one or two custom distributions for each data science project. What do you think of a simple patched distribution, just for your project?
from pyro.distributions import ExpRelaxedCategorical
class SafeExpRelaxedCategorical(ExpRelaxedCategorical):
epsilon = 1e-10
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
uniforms = clamp_probs(
torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
)
gumbels = -((-(uniforms.log())).log())
scores = (self.logits + gumbels) / self.temperature
#could also clamp_probs
outs = scores - scores.logsumexp(dim=-1, keepdim=True)
outs = outs.exp()
outs = outs + self.epsilon # prevent underflow
outs = (outs / outs.sum(1, keepdim=True)).log()
return outs
Actually, I often find that (1) clamping is safer than adding, and (2) it's best to use `torch.finfo(dtype).tiny`
rather than a hard-coded epsilon. So you might customize:
class SafeExpRelaxedCategorical2(ExpRelaxedCategorical):
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
uniforms = clamp_probs(
torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
)
gumbels = -((-(uniforms.log())).log())
scores = (self.logits + gumbels) / self.temperature
#could also clamp_probs
outs = scores - scores.logsumexp(dim=-1, keepdim=True)
outs = outs.exp()
outs = outs.clamp(min=torch.finfo(outs.dtype).tiny)
outs = (outs / outs.sum(1, keepdim=True)).log()
return outs
WDYT?
Hi @fritzo, I agree in principle: you're right about the hackability, as well as about using a proper epsilon or `torch.finfo(...).tiny` (I'm still working on my coding modularity :)). I do think it's important to fix the default, though. I used Pyro for two years and thought RelaxedOneHotCategorical was totally unusable, because it seems to fail in the following example:
import pyro
import torch
import pyro.distributions as dist
def model(logits):
    """Sample ``cat_sample`` from a fixed relaxed one-hot prior over 1000 classes.

    ``logits`` is unused; it is accepted only so that ``model`` and ``guide``
    share the same signature for SVI.
    """
    prior = dist.RelaxedOneHotCategorical(
        temperature=0.1 * torch.ones(1),
        logits=torch.zeros(1, 1000),
    )
    pyro.sample('cat_sample', prior)
def guide(logits):
    """Sample ``cat_sample`` from a relaxed one-hot using caller-supplied ``logits``.

    No ``pyro.param`` is registered here, so SVI has nothing to learn; the
    guide exists to reproduce the NaN ``log_prob`` at low temperature.
    """
    posterior = dist.RelaxedOneHotCategorical(
        temperature=0.1 * torch.ones(1),
        logits=logits,
    )
    pyro.sample('cat_sample', posterior)
# Repro: ten SVI steps, each feeding freshly drawn random logits to the guide.
pyro.clear_param_store()
svi = pyro.infer.SVI(
    model=model,
    guide=guide,
    optim=pyro.optim.Adam({"lr": 0.1}),
    loss=pyro.infer.Trace_ELBO(),
)
for i in range(10):
    logits = torch.randn(1, 1000)
    loss = svi.step(logits)
This produces the following NaN warning, caused by underflow:
warn_if_nan(
.../pyro/lib/python3.11/site-packages/pyro/poutine/trace_struct.py:285: UserWarning: Encountered NaN: log_prob_sum at site 'cat_sample'
You're right about the fix. For instance, your first version resolves the underflow issue more elegantly than what I proposed:
import pyro.distributions
from torch.distributions.relaxed_categorical import ExpRelaxedCategorical
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.distributions import TransformedDistribution
class SafeExpRelaxedCategorical(ExpRelaxedCategorical):
epsilon = 1e-10
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
uniforms = clamp_probs(
torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
)
gumbels = -((-(uniforms.log())).log())
scores = (self.logits + gumbels) / self.temperature
#could also clamp_probs
outs = scores - scores.logsumexp(dim=-1, keepdim=True)
outs = outs.exp()
outs = outs + self.epsilon # prevent underflow
outs = (outs / outs.sum(1, keepdim=True)).log()
return outs
class SafeRelaxedOneHotCategorical(TransformedDistribution, TorchDistributionMixin):
    r"""
    Underflow-safe relaxed one-hot categorical, parametrized by
    :attr:`temperature` and either :attr:`probs` or :attr:`logits`.

    Built as ``ExpTransform`` applied to ``SafeExpRelaxedCategorical``, so
    samples lie on the simplex and are reparametrizable. The base
    distribution keeps its log-space samples finite at low temperatures.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = SafeRelaxedOneHotCategorical(torch.tensor([2.2]),
        ...                                  torch.tensor([[0.1, 0.2, 0.3, 0.4]]))
        >>> m.sample()
        tensor([[ 0.1294,  0.2324,  0.3859,  0.2523]])

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        base_dist = SafeExpRelaxedCategorical(
            temperature, probs, logits, validate_args=validate_args
        )
        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        # Check against this class itself. Checking against
        # RelaxedOneHotCategorical (a class with a different __init__, and not
        # imported alongside this snippet) fails when _instance is None.
        new = self._get_checked_instance(SafeRelaxedOneHotCategorical, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def temperature(self):
        """Relaxation temperature of the base distribution."""
        return self.base_dist.temperature

    @property
    def logits(self):
        """Normalized logits of the base distribution."""
        return self.base_dist.logits

    @property
    def probs(self):
        """Event probabilities of the base distribution."""
        return self.base_dist.probs
def model(logits):
    """Sample ``cat_sample`` from a fixed, underflow-safe relaxed one-hot prior.

    ``logits`` is unused; it is accepted only so that ``model`` and ``guide``
    share the same signature for SVI.
    """
    prior = SafeRelaxedOneHotCategorical(
        temperature=0.1 * torch.ones(1),
        logits=torch.ones(1, 1000),
    )
    pyro.sample('cat_sample', prior)
def guide(logits):
    """Sample ``cat_sample`` from the underflow-safe relaxed one-hot with given ``logits``.

    No ``pyro.param`` is registered here, so SVI has nothing to learn; the
    guide only checks that the safe distribution avoids the NaN ``log_prob``.
    """
    posterior = SafeRelaxedOneHotCategorical(
        temperature=0.1 * torch.ones(1),
        logits=logits,
    )
    pyro.sample('cat_sample', posterior)
# Same ten-step SVI loop, now using the underflow-safe distribution.
pyro.clear_param_store()
svi = pyro.infer.SVI(
    model=model,
    guide=guide,
    optim=pyro.optim.Adam({"lr": 0.1}),
    loss=pyro.infer.Trace_ELBO(),
)
for i in range(10):
    logits = torch.randn(1, 1000)
    loss = svi.step(logits)
This runs without the NaN warning, as does my SafeAndRelaxedOneHotCategoricalStraightThrough above.
So, yeah, it seems like the default for RelaxedOneHotCategorical should use one of these SafeExpRelaxedCategorical bases you've proposed here?
I've noticed that pyro.distributions.RelaxedOneHotCategorical tends to underflow pretty dramatically if you decrease the temperature below 0.3 or so with many categories. I've been adding a slight modification to the rsample function of the ExpRelaxedCategorical class it's built on. Just wanted to post this in case you want to consider this (maybe hacky) fix to make this distribution work with pyro support constraints.
Modified from: https://github.com/pytorch/pytorch/blob/main/torch/distributions/relaxed_categorical.py