pytorch / tutorials

PyTorch tutorials.
https://pytorch.org/tutorials/
BSD 3-Clause "New" or "Revised" License
8.22k stars 4.07k forks source link

[BUG] - Per sample gradients using function transforms not working for RNN #2566

Closed bnuliujing closed 1 year ago

bnuliujing commented 1 year ago

Add Link

Hello! I'm working on a optimization algorithm that requires computing the per sample gradients. Assuming the batch size is $N$ and the number of model parameters is $M$, I want to calculate $\partial \log p(\mathbf{x}^{(i)};\theta)/\partial \theta_j$, which is an $N \times M$ matrix. I found the [PER-SAMPLE-GRADIENTS](https://pytorch.org/tutorials/intermediate/per_sample_grads.html) tutorial and began my own experiments. As a proof of concept, I defined a generative model with a tractable likelihood, such as MADE (Masked Autoencoder for Distribution Estimation), PixelCNN, RNN, etc., and sepcified the log_prob and sample methods. I utilized the function transforms methods mentioned in the tutorial, but currently, it only works for MADE (I believed it would work for NADE and PixelCNN too, since these models need only one forward pass to calculate the log likelihood of $\mathbf{x}$. For RNN however, both sampling and inference require $N$ forward pass). Below, I've provided my code snippets, and I'm interested in figuring out why it's not working for RNN. Making it work for RNN would significantly reduce the number of parameters for my research purpose. Thank you!

Describe the bug

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class MADE(nn.Module):
    '''A simple one-layer MADE (Masked Autoencoder for Distribution Estimation)'''

    def __init__(self, n=10, device='cpu', *args, **kwargs):
        super().__init__()
        self.n = n
        self.device = device

        self.weight = nn.Parameter(torch.randn(self.n, self.n) / math.sqrt(self.n))
        self.bias = nn.Parameter(torch.zeros(self.n))
        mask = torch.tril(torch.ones(self.n, self.n), diagonal=-1)
        self.register_buffer('mask', mask)

    def pred_logits(self, x):
        return F.linear(x, self.mask * self.weight, self.bias)

    def forward(self, x):
        logits = self.pred_logits(x)
        log_probs = - F.binary_cross_entropy_with_logits(logits, x, reduction='none')
        return log_probs.sum(-1)

    @torch.no_grad()
    def sample(self, batch_size):
        x = torch.zeros(batch_size, self.n, dtype=torch.float, device=self.device)
        for i in range(self.n):
            logits = self.pred_logits(x)[:, i]
            x[:, i] = torch.bernoulli(torch.sigmoid(logits))
        return x

class GRUModel(nn.Module):
    '''GRU for density estimation'''

    def __init__(self, n=10, input_size=2, hidden_size=8, device='cpu'):
        super().__init__()
        self.n = n
        self.input_size = input_size  # input_size=2 when x is binary
        self.hidden_size = hidden_size
        self.device = device
        self.gru_cell = nn.GRUCell(self.input_size, self.hidden_size)
        self.fc_layer = nn.Linear(self.hidden_size, 1)

    def pred_logits(self, x, h=None):
        x = torch.stack([x, 1 - x], dim=1)  # 1 -> (1, 0), 0 -> (0, 1), (batch_size, 2)
        h_next = self.gru_cell(x, h)  # h_{i+1}
        logits = self.fc_layer(h_next).squeeze(1)
        return h_next, logits

    def forward(self, x):
        log_prob_list = []
        x = torch.cat([torch.zeros(x.shape[0], 1, dtype=torch.float, device=self.device), x], dim=1)  # cat x_0
        h = torch.zeros(x.shape[0], self.hidden_size, dtype=torch.float, device=self.device)  # h_0
        for i in range(self.n):
            h, logits = self.pred_logits(x[:, i], h)
            log_prob = - F.binary_cross_entropy_with_logits(logits, x[:, i + 1], reduction='none')
            log_prob_list.append(log_prob)
        return torch.stack(log_prob_list, dim=1).sum(dim=1)

    @torch.no_grad()
    def sample(self, batch_size):
        x = torch.zeros(batch_size, self.n + 1, dtype=torch.float, device=self.device)
        for i in range(self.n):
            h, logits = self.pred_logits(x[:, i], h=None if i == 0 else h)
            x[:, i + 1] = torch.bernoulli(torch.sigmoid(logits))
        return x[:, 1:]

if __name__ == '__main__':
    model = MADE()
    # model = GRUModel()

    # Sample from the generative model
    samples = model.sample(128)

    # Then I use the function transforms methods mentioned in the tutorial
    # to calculate the per sample mean
    from torch.func import functional_call, grad, vmap
    params = {k: v.detach() for k, v in model.named_parameters()}

    def loss_fn(log_probs):
        return log_probs.mean(0)

    def compute_loss(params, sample):
        batch = sample.unsqueeze(0)
        log_prob = functional_call(model, (params,), (batch,))
        loss = loss_fn(log_prob)
        return loss

    ft_compute_grad = grad(compute_loss)
    ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0))
    ft_per_sample_grads = ft_compute_sample_grad(params, samples)

    print(ft_per_sample_grads)

The above code works for MADE (I also check the values of gradients, they are correct!) However, when I use model = GRUModel(), an error arises:

Traceback (most recent call last):
  File "per_sample_grads.py", line 100, in <module>
    ft_per_sample_grads = ft_compute_sample_grad(params, samples)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 434, in wrapped
    return _flat_vmap(
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 619, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1380, in wrapper
    results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1245, in wrapper
    output = func(*args, **kwargs)
  File "per_sample_grads.py", line 94, in compute_loss
    log_prob = functional_call(model, (params,), (batch,))
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/functional_call.py", line 143, in functional_call
    return nn.utils.stateless._functional_call(
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/utils/stateless.py", line 262, in _functional_call
    return module(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "per_sample_grads.py", line 63, in forward
    h, logits = self.pred_logits(x[:, i], h)
  File "per_sample_grads.py", line 54, in pred_logits
    h_next = self.gru_cell(x, h)  # h_{i+1}
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 1327, in forward
    ret = _VF.gru_cell(
RuntimeError: output with shape [1, 8] doesn't match the broadcast shape [128, 1, 8]

Describe your environment

The above code is also tested on Ubuntu 18.04 with PyTorch 2.0.1, CUDA 11.7/11.8.

svekars commented 1 year ago

Hi @bnuliujing, can you please post in https://dev-discuss.pytorch.org?