cybertronai / autograd-hacks

Support gradient computation in multiple forward passes #4

Closed jusjusjus closed 4 years ago

jusjusjus commented 4 years ago

I rewrote parts of your module to allow computing gradients from multiple forward passes. The use case is best summarized in the test I added:

import torch
import torch.nn as nn

import autograd_hacks

# `Net` is the small example model already used by the existing tests.

def test_grad1_for_multiple_passes():
    torch.manual_seed(42)
    model = Net()
    loss_fn = nn.CrossEntropyLoss()

    def get_data(batch_size):
        return (torch.rand(batch_size, 1, 28, 28),
                torch.LongTensor(batch_size).random_(0, 10))

    n1 = 4
    n2 = 10

    autograd_hacks.add_hooks(model)

    # First pass: keep a copy of the averaged gradients for later comparison.
    data, targets = get_data(n1)
    output = model(data)
    loss_fn(output, targets).backward(retain_graph=True)
    grads = [{n: p.grad.clone() for n, p in model.named_parameters()}]
    model.zero_grad()

    # Second pass with a different batch size.
    data, targets = get_data(n2)
    output = model(data)
    loss_fn(output, targets).backward(retain_graph=True)
    grads.append({n: p.grad for n, p in model.named_parameters()})

    # Compute per-sample gradients for both passes.
    autograd_hacks.compute_grad1(model)

    autograd_hacks.disable_hooks()

    # For each pass, the mean of the per-sample gradients must match .grad.
    for n, p in model.named_parameters():
        for i, grad in enumerate(grads):
            assert grad[n].shape == p.grad1[i].shape[1:]
            assert torch.allclose(grad[n], p.grad1[i].mean(dim=0))

The pull request also addresses Issue #3. This does require an additional index on the grad1 property, which now becomes a list with one entry per forward pass. See if you like this major change to your API.
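
To make the change concrete, here is a minimal sketch of how per-sample gradients would be read back under the proposed API, reusing model, n1, and n2 from the test above; the local variable names are only for illustration:

# Sketch only: p.grad1 is now a list with one tensor per forward pass,
# each of shape (batch_size, *p.shape).
for name, param in model.named_parameters():
    grad1_first, grad1_second = param.grad1           # one entry per pass
    assert grad1_first.shape == (n1, *param.shape)    # per-sample grads, pass 1
    assert grad1_second.shape == (n2, *param.shape)   # per-sample grads, pass 2
    # Averaging over the sample dimension recovers the per-pass .grad
    # (CrossEntropyLoss uses mean reduction by default).
    mean_grad_first = grad1_first.mean(dim=0)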

In the future, I'd also like to release this module on PyPI to make maintenance easier.