harvardnlp / pytorch-struct

Fast, general, and tested differentiable structured prediction in PyTorch
http://harvardnlp.github.io/pytorch-struct
MIT License

differentiable samples (rsample) #82

Open RafaelPo opened 3 years ago

RafaelPo commented 3 years ago

Are there plans to introduce differentiable samples?

Thanks!

srush commented 3 years ago

Yeah, we are actually trying that out right now. There are a lot of different ways to do it with discrete distributions; did you have one in mind?

RafaelPo commented 3 years ago

Hi,

I was thinking of applying the results from https://arxiv.org/pdf/2002.08676.pdf recursively on the marginals... do you think that would work?

srush commented 3 years ago

Yes, I think that would be cool. We have some of the papers referenced in that work already implemented, such as the differentiable dynamic programming semiring, but it is not exposed in the API. I'm a bit hesitant to call it rsample, because it is biased. Maybe we should have a separate API function that exposes some of these tricks? If you are interested, we would be happy to take a contribution.

RafaelPo commented 3 years ago

Hi,

here is some code I have been playing with: [code attached as two screenshots]

srush commented 3 years ago

Nice, that is similar in spirit to the code we have been working on in https://github.com/harvardnlp/pytorch-struct/pull/81.

We can integrate them both into the library.

There might also be a way to do this that calls cvxpy far fewer times.

RafaelPo commented 3 years ago

I will have a look, thanks!

How could you save on the number of runs?

Also, I think they are supposed to be unbiased, no?

srush commented 3 years ago

Very neat. So I think that instead of first computing marginals, we can apply this approach in the backward operation of the semiring itself. This is how I compute unbiased gumbel-max samples (https://github.com/harvardnlp/pytorch-struct/pull/81/files#diff-5775ad09d6cfbdc4d52edd6797aba8e68ac66ae04dee680b0e456058bef106dcR70) .

It seems like I can just change this line (https://github.com/harvardnlp/pytorch-struct/pull/81/files#diff-5775ad09d6cfbdc4d52edd6797aba8e68ac66ae04dee680b0e456058bef106dcR71) from an argmax to your CVX code to get a differentiable sample? This should work for all models.

Another advantage of this method is that it will batch across n (our internal code does log n steps instead of n for a linear chain).

I agree the forward sample is unbiased, but I will have to read the paper to understand whether the gradient is unbiased too (but I believe you).
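
For concreteness, here is a toy single-variable sketch of that swap (just the local idea, not the semiring code; the softmax is a simple stand-in for the CVX projection):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(5, requires_grad=True)

# Gumbel-max: the argmax of the perturbed scores is an exact (unbiased) sample
gumbel = -torch.log(-torch.log(torch.rand_like(scores)))
perturbed = scores + gumbel
hard = F.one_hot(perturbed.argmax(), scores.shape[-1]).float()  # exact sample, no gradient

# Relaxation: swap the argmax for a temperature softmax
# (a simple stand-in for the CVX projection discussed above)
temp = 0.5
soft = torch.softmax(perturbed / temp, dim=-1)  # differentiable, but biased

# Straight-through variant: hard sample forward, relaxed gradient backward
sample = hard + soft - soft.detach()
(sample * torch.randn(5)).sum().backward()
print(scores.grad)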

teffland commented 3 years ago

Hi,

Not sure how this compares to what you guys have been working on, but for what it's worth, I have implemented a version of a biased rsample that uses local Gumbel perturbations and temperature-controlled marginals (this is the marginal stochastic softmax trick from https://arxiv.org/abs/2006.08063), directly in the StructDistribution class:

def rsample(self, sample_shape=torch.Size(), temp=1.0, noise_shape=None, sample_batch_size=10):
        r"""
        Compute structured samples from the _relaxed_ distribution :math:`z \sim p(z;\theta+\gamma, \tau)`

        NOTE: These samples are biased.

        This uses Gumbel perturbations on the potentials followed by above-zero-temperature marginals to get approximate samples.
        As temp varies from 0 to inf the samples vary from exact one-hots drawn from an approximate distribution to
        a deterministic output that is uniform over all values.

        The approximation of the zero-temp limit comes from the fact that we use polynomial (instead of exponential)
        perturbations, see:
          [Perturb-and-MAP](https://ttic.uchicago.edu/~gpapan/pubs/confr/PapandreouYuille_PerturbAndMap_ieee-c-iccv11.pdf)
          [Stochastic Softmax Tricks](https://arxiv.org/abs/2006.08063)

        Parameters:
            sample_shape (int or torch.Size): number of samples
            temp (float): (default=1.0) relaxation temperature
            noise_shape (torch.Size): specify lower-order perturbations by placing ones along any of the potential dims
            sample_batch_size (int): batch size used when computing samples

        Returns:
            samples (*sample_shape x batch_shape x event_shape*)

        """
        # Sanity checks
        if type(sample_shape) == int:
            nsamples = sample_shape
        else:
            assert len(sample_shape) == 1
            nsamples = sample_shape[0]
        if sample_batch_size > nsamples:
            sample_batch_size = nsamples
        samples = []

        if noise_shape is None:
            noise_shape = self.log_potentials.shape[1:]

        assert len(noise_shape) == len(self.log_potentials.shape[1:])
        assert all(
            s1 == 1 or s1 == s2 for s1, s2 in zip(noise_shape, self.log_potentials.shape[1:])
        ), f"Noise shapes must match dimension or be 1: got: {list(zip(noise_shape, self.log_potentials.shape[1:]))}"

        # Sampling
        for k in range(nsamples):
            if k % sample_batch_size == 0:
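                # Draw sample_batch_size samples at once by tiling the potentials along the batch dimension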
                shape = self.log_potentials.shape
                B = shape[0]
                s_log_potentials = (
                    self.log_potentials.reshape(1, *shape)
                    .repeat(sample_batch_size, *tuple(1 for _ in shape))
                    .reshape(-1, *shape[1:])
                )

                s_lengths = self.lengths
                if s_lengths is not None:
                    s_shape = s_lengths.shape
                    s_lengths = (
                        s_lengths.reshape(1, *s_shape)
                        .repeat(sample_batch_size, *tuple(1 for _ in s_shape))
                        .reshape(-1, *s_shape[1:])
                    )

                noise = (
                    torch.distributions.Gumbel(0, 1)
                    .sample((sample_batch_size * B, *noise_shape))
                    .expand_as(s_log_potentials)
                ).to(s_log_potentials.device)
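                # Perturb and scale by temperature; lower temp pushes the marginals toward hard one-hot samples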
                noisy_potentials = (s_log_potentials + noise) / temp

                r_sample = (
                    self._struct(LogSemiring)
                    .marginals(noisy_potentials, s_lengths)
                    .reshape(sample_batch_size, B, *shape[1:])
                )
                samples.append(r_sample)
        return torch.cat(samples, dim=0)[:nsamples]
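
For reference, here is a rough usage sketch, assuming the method above is added to StructDistribution (the LinearChainCRF shapes are just for illustration):

import torch
import torch_struct

batch, N, C = 4, 10, 6
log_potentials = torch.randn(batch, N - 1, C, C, requires_grad=True)
dist = torch_struct.LinearChainCRF(log_potentials)

# 16 relaxed samples with shape (16, batch, N - 1, C, C)
samples = dist.rsample(sample_shape=16, temp=0.5)

# gradients flow back to log_potentials through the relaxed marginals
(samples * torch.randn_like(samples)).sum().backward()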

Let me know if you'd like me to submit it as a PR (with whatever changes you think make sense).

Thanks, Tom

srush commented 3 years ago

Awesome, it sounds like we have three different methods. The one in my PR is from Yao's NeurIPS work (https://arxiv.org/abs/2011.14244), which is unbiased in the forward pass and biased in the backward pass. Maybe we should have a phone call to figure out the differences and how to document and compare them.

teffland commented 3 years ago

Very interesting, I'll take a look at the paper; an unbiased forward pass sounds like a big plus. I'm available for a call to discuss pretty much whenever.

RafaelPo commented 3 years ago

Not sure how what I proposed compares to the rest; it seems (way) more computationally expensive. I would be interested in a call as well, though I am based in England.