Open RafaelPo opened 3 years ago

Are there plans to introduce differentiable samples?

Thanks!
Yeah... we are actually trying that out currently. There are a lot of different ways to do it with discrete distributions; did you have one in mind?
Hi,
I was thinking of applying results from https://arxiv.org/pdf/2002.08676.pdf recursively on the marginals... do you think that would work?
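For concreteness, here is a rough sketch of the perturbed-maximizer idea from that paper, just over the probability simplex (so the inner LP is an argmax) and with cvxpy standing in for a generic solver. The `perturbed_argmax` helper, the Gumbel noise, and the simplex polytope are only illustrative; the recursion over structured marginals and the paper's gradient estimator are not shown:

```python
import numpy as np
import cvxpy as cp

def perturbed_argmax(theta, n_samples=100, sigma=1.0, seed=None):
    """Monte Carlo estimate of the perturbed maximizer
    y_sigma(theta) = E_Z[ argmax_{y in simplex} <y, theta + sigma * Z> ],
    i.e. the smoothing trick from arXiv:2002.08676, here over the simplex."""
    rng = np.random.default_rng(seed)
    d = theta.shape[0]
    y = cp.Variable(d)
    theta_p = cp.Parameter(d)
    # LP over the probability simplex; its vertices are one-hot vectors,
    # so each inner solve is just a (hard) argmax.
    problem = cp.Problem(cp.Maximize(theta_p @ y), [y >= 0, cp.sum(y) == 1])
    out = np.zeros(d)
    for _ in range(n_samples):
        theta_p.value = theta + sigma * rng.gumbel(size=d)
        problem.solve()
        out += y.value
    return out / n_samples

print(perturbed_argmax(np.array([1.0, 2.0, 0.5]), n_samples=50))
```

With Gumbel noise over the simplex this should converge to softmax(theta), which is a convenient sanity check.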
Yes, I think that would be cool. We already have some of the papers referenced in that work implemented, such as the differentiable dynamic-programming semiring, but it is not exposed in the API. I'm a bit hesitant to call it rsample, because it is biased. Maybe we should have a separate API function that exposes some of these tricks? If you are interested, I would be happy to take a contribution.
Hi,
here is some code I have been playing with:
Nice, that is similar in spirit to this code we have been working on: https://github.com/harvardnlp/pytorch-struct/pull/81 .
We can integrate them both into the library.
There might also be a way to do this that calls cvxpy many fewer times.
I will have a look, thanks!
How could you save on the number of runs?
Also, I think they are supposed to be unbiased, no?
Very neat. So I think that instead of first computing marginals, we can apply this approach in the backward operation of the semiring itself. This is how I compute unbiased Gumbel-max samples (https://github.com/harvardnlp/pytorch-struct/pull/81/files#diff-5775ad09d6cfbdc4d52edd6797aba8e68ac66ae04dee680b0e456058bef106dcR70).
It seems like I can just change this line (https://github.com/harvardnlp/pytorch-struct/pull/81/files#diff-5775ad09d6cfbdc4d52edd6797aba8e68ac66ae04dee680b0e456058bef106dcR71) from an argmax to your CVX code to get a differentiable sample? This should work for all models.
Another advantage of this method is that it will batch across n (our internal code does log n steps instead of n for the linear chain).
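For anyone following along, here is the Gumbel-max identity that makes the forward sample exact. This is a minimal standalone sketch, not the semiring code from the PR; the relaxation is whatever ends up replacing the hard argmax:

```python
import torch

# Gumbel-max identity: argmax_i(theta_i + G_i), with G_i ~ Gumbel(0, 1),
# is an exact (unbiased) sample from softmax(theta). Replacing the hard
# argmax with a relaxed maximizer is what makes the sample differentiable,
# at the cost of bias in the relaxation.
torch.manual_seed(0)
theta = torch.tensor([1.0, 2.0, 0.5])

n = 100_000
gumbels = -torch.log(-torch.log(torch.rand(n, theta.numel())))
hard = torch.argmax(theta + gumbels, dim=-1)

empirical = torch.bincount(hard, minlength=theta.numel()).float() / n
print(empirical)                    # approximately equals ...
print(torch.softmax(theta, dim=0))  # ... the true categorical probabilities
```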
I agree the forward sample is unbiased, but I will have to read the paper to understand whether the gradient is unbiased too (but I believe you).
Hi,
Not sure how this compares to what you guys have been working on, but for what it's worth, I have implemented a version of a biased rsample that uses local Gumbel perturbations and temperature-controlled marginals (this is the marginal stochastic softmax trick from https://arxiv.org/abs/2006.08063) directly in the StructDistribution class as:
```python
def rsample(self, sample_shape=torch.Size(), temp=1.0, noise_shape=None, sample_batch_size=10):
    r"""
    Compute structured samples from the _relaxed_ distribution :math:`z \sim p(z;\theta+\gamma, \tau)`

    NOTE: These samples are biased.

    This uses Gumbel perturbations on the potentials followed by the zero-temp marginals
    to get approximate samples. As temp varies from 0 to inf, the samples vary from exact
    one-hots drawn from an approximate distribution to a deterministic output that is
    uniform over all values. The approximation of the zero-temp limit comes from the fact
    that we use polynomial (instead of exponential) perturbations, see:

    [Perturb-and-MAP](https://ttic.uchicago.edu/~gpapan/pubs/confr/PapandreouYuille_PerturbAndMap_ieee-c-iccv11.pdf)
    [Stochastic Softmax Tricks](https://arxiv.org/abs/2006.08063)

    Parameters:
        sample_shape (int): number of samples
        temp (float): (default=1.0) relaxation temperature
        noise_shape (torch.Size): specify lower-order perturbations by placing ones along any of the potential dims
        sample_batch_size (int): size of the batches used to compute samples

    Returns:
        samples (*sample_shape x batch_shape x event_shape*)
    """
    # Sanity checks
    if isinstance(sample_shape, int):
        nsamples = sample_shape
    else:
        assert len(sample_shape) == 1
        nsamples = sample_shape[0]
    if sample_batch_size > nsamples:
        sample_batch_size = nsamples
    samples = []
    if noise_shape is None:
        noise_shape = self.log_potentials.shape[1:]
    assert len(noise_shape) == len(self.log_potentials.shape[1:])
    assert all(
        s1 == 1 or s1 == s2 for s1, s2 in zip(noise_shape, self.log_potentials.shape[1:])
    ), f"Noise shapes must match dimension or be 1: got: {list(zip(noise_shape, self.log_potentials.shape[1:]))}"

    # Sampling
    for k in range(nsamples):
        if k % sample_batch_size == 0:
            # Tile the potentials (and lengths) so one marginal call computes a whole batch of samples.
            shape = self.log_potentials.shape
            B = shape[0]
            s_log_potentials = (
                self.log_potentials.reshape(1, *shape)
                .repeat(sample_batch_size, *tuple(1 for _ in shape))
                .reshape(-1, *shape[1:])
            )
            s_lengths = self.lengths
            if s_lengths is not None:
                s_shape = s_lengths.shape
                s_lengths = (
                    s_lengths.reshape(1, *s_shape)
                    .repeat(sample_batch_size, *tuple(1 for _ in s_shape))
                    .reshape(-1, *s_shape[1:])
                )
            # Gumbel(0, 1) perturbations, broadcast from noise_shape to the full potential shape.
            noise = (
                torch.distributions.Gumbel(0, 1)
                .sample((sample_batch_size * B, *noise_shape))
                .expand_as(s_log_potentials)
            ).to(s_log_potentials.device)
            noisy_potentials = (s_log_potentials + noise) / temp
            # Marginals of the perturbed, tempered potentials are the relaxed samples.
            r_sample = (
                self._struct(LogSemiring)
                .marginals(noisy_potentials, s_lengths)
                .reshape(sample_batch_size, B, *shape[1:])
            )
            samples.append(r_sample)
    return torch.cat(samples, dim=0)[:nsamples]
```
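A hypothetical usage sketch, assuming the method above ends up attached to StructDistribution and using a LinearChainCRF just to make the shapes concrete (the sizes below are arbitrary):

```python
# Hypothetical usage sketch (not part of the implementation above):
# relaxed samples from a linear-chain CRF, differentiable w.r.t. the
# log-potentials because they are computed via marginals.
import torch
from torch_struct import LinearChainCRF

batch, N, C = 2, 5, 3
log_potentials = torch.randn(batch, N - 1, C, C, requires_grad=True)
dist = LinearChainCRF(log_potentials)

samples = dist.rsample((16,), temp=0.5)  # 16 x batch x (N-1) x C x C
samples.sum().backward()                 # gradients flow back to log_potentials
```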
Let me know if you'd like me to submit it as a PR (with whatever changes you think make sense).
Thanks, Tom
Awesome, sounds like we have three different methods. The one in my PR is from Yao's NeurIPS work (https://arxiv.org/abs/2011.14244), which is unbiased in the forward pass and biased in the backward pass. Maybe we should have a phone call to figure out the differences and how to document and compare them.
Very interesting, I'll take a look at the paper -- unbiased forward sounds like a big plus. I'm available for a call to discuss pretty much whenever.
Not sure how what I proposed compares to the rest; it seems (way) more computationally expensive. I would be interested in a call as well, but I am based in England.