Scheduled sampling for PyTorch

This repository provides Python code, with C++ and CUDA extensions, for efficient importance sampling from either a single discrete distribution or a product of independent discrete distributions. The task is nontrivial because the samples are guaranteed to be distinct, which takes some care to ensure for a product of discrete distributions.



pip install torch-discounted-cumsum


See [torch_scheduled_sampling/sampling.py]:

def sample_combined(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Sample from a distribution that is the product of softmaxes.  We will sample
    K distinct samples.  This entails using sampling weights of the form min(1, p/beta)
    for a computed beta.

    Args:
      p: A Tensor of shape (*, N, M): either normalized log-probs (if input_is_log==True),
         or normalized probabilities; normalized along the M axis.  N must be in [1,2,3,4];
         in the common case N will be 1, and you can use unsqueeze() to add that axis.
      K: An integer, the number of samples required, with 0 < K < M.
      input_is_log: True if p represents normalized log-probs, False if it represents
         probabilities.

    Returns: (indexes, combined_indexes, weights)
      indexes: of shape (*, K, N), for each of K samples from a distribution it contains
            an N-tuple of indexes saying which combination of indexes from the
            component distributions were sampled.
      combined_indexes: of shape (*, K), contains the same information as `indexes` but
            in a different format, specifically:
               `combined_indexes[...,k] = sum_n indexes[...,k,n] * M**n`
      weights: of shape (*, K), gives the weight associated with each sample,
            which will equal max(p, beta) for a beta specific to the batch element,
            i.e. to the input product distribution (0 < beta <= 1/K).  The
            weights will sum to 1 along the K axis.  The class-inclusion
            probabilities in the sample would be given by (p / weights), assuming
            input_is_log == False, or (p.exp() / weights) otherwise; but
            we output them in this format as it is less likely to lead to
            large numbers in backprop.
    """
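
To make the role of beta concrete, here is a minimal pure-Python sketch for the single-distribution case (N = 1). It is not the repository's implementation, which also handles products of distributions and runs in C++/CUDA; the helper names `compute_beta`, `inclusion_probs` and `sample_weights` are hypothetical. It assumes beta is chosen so that the inclusion probabilities min(1, p/beta) sum to K, which is consistent with the docstring's statement that the weights max(p, beta) of the K samples sum to 1.

```python
def compute_beta(p, K):
    """Find beta such that sum_i min(1, p_i / beta) == K.

    p: list of probabilities summing to 1; K: number of samples, 0 < K < len(p).
    (Hypothetical helper for illustration, not part of the library.)
    """
    assert 0 < K < len(p)
    ps = sorted(p, reverse=True)
    tail = sum(ps)  # total mass of the classes not yet treated as "always included"
    beta = tail / K
    for n_big in range(K):
        # Suppose the n_big largest classes are always included (p_i >= beta);
        # the remaining mass must then fill the other K - n_big slots.
        beta = tail / (K - n_big)
        if ps[n_big] <= beta:
            return beta
        tail -= ps[n_big]
    return beta  # fallback, only reachable through floating-point round-off


def inclusion_probs(p, beta):
    # Probability that each class appears among the K distinct samples.
    return [min(1.0, x / beta) for x in p]


def sample_weights(p, beta):
    # Weight attached to a class if it is sampled: p divided by its inclusion probability.
    return [max(x, beta) for x in p]


p = [0.7, 0.1, 0.1, 0.1]
K = 2
beta = compute_beta(p, K)
q = inclusion_probs(p, beta)
w = sample_weights(p, beta)
```

In this example the first class has probability above beta, so it is always included with weight p; whichever of the other classes is drawn gets weight beta, and the two weights add up to 1.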


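import torch
from torch_scheduled_sampling import sample_combined

# Batch of 3 distributions over M=8 classes, each with N=1 component distribution.
logprobs = (2 * torch.randn(3, 1, 8)).log_softmax(dim=2)

# Draw K=2 distinct samples per batch element.
indexes, _combined_indexes, weights = sample_combined(logprobs, K=2, input_is_log=True)

# log(weight / p) for each sample, i.e. minus the log of its inclusion probability.
importance_logprobs = weights.log() - torch.gather(logprobs, dim=2, index=indexes.transpose(1, 2)).squeeze(1)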


torch.Size([3, 2, 1])  # indexes.shape
tensor([[0, 6],
        [1, 6],
        [2, 4]])    # indexes.squeeze(-1)
tensor([[0.7474, 0.2526],
        [0.1940, 0.8060],
        [0.5000, 0.5000]])  # weights
tensor([[ 8.9407e-08,  1.7194e-01],
        [ 1.0342e+00, -2.9802e-08],
        [ 5.6074e-01,  5.0099e-01]])  # importance_logprobs
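
The `_combined_indexes` output, ignored in the example above, packs each N-tuple of indexes into a single integer using the formula from the docstring, `sum_n indexes[...,k,n] * M**n`. The following pure-Python sketch shows that encoding and its inverse for one sample; the function names `combine_indexes` and `split_combined_index` are hypothetical and are not part of the library.

```python
def combine_indexes(index_tuple, M):
    """Encode an N-tuple of indexes, each in [0, M), as one base-M integer."""
    return sum(idx * M ** n for n, idx in enumerate(index_tuple))


def split_combined_index(combined, M, N):
    """Invert combine_indexes: recover the N-tuple from the combined integer."""
    out = []
    for _ in range(N):
        out.append(combined % M)  # lowest-order base-M digit is component n
        combined //= M
    return tuple(out)


M = 8
pair = (3, 5)                      # indexes into N=2 component distributions
combined = combine_indexes(pair, M)
recovered = split_combined_index(combined, M, N=2)
```

With N = 1, as in the example above, the combined index is simply the index itself.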