I rewrote parts of your module to allow the computation of gradients from multiple forward passes. The use case is best summarized in the test, which I added:
```python
def test_grad1_for_multiple_passes():
    torch.manual_seed(42)
    model = Net()
    loss_fn = nn.CrossEntropyLoss()

    def get_data(batch_size):
        return (torch.rand(batch_size, 1, 28, 28),
                torch.LongTensor(batch_size).random_(0, 10))

    n1 = 4
    n2 = 10

    autograd_hacks.add_hooks(model)

    # First forward/backward pass with batch size n1
    data, targets = get_data(n1)
    output = model(data)
    loss_fn(output, targets).backward(retain_graph=True)
    grads = [{n: p.grad.clone() for n, p in model.named_parameters()}]
    model.zero_grad()

    # Second forward/backward pass with batch size n2
    data, targets = get_data(n2)
    output = model(data)
    loss_fn(output, targets).backward(retain_graph=True)
    grads.append({n: p.grad for n, p in model.named_parameters()})

    autograd_hacks.compute_grad1(model)
    autograd_hacks.disable_hooks()

    # grad1 is now a list: one per-example gradient tensor per pass
    for n, p in model.named_parameters():
        for i, grad in enumerate(grads):
            assert grad[n].shape == p.grad1[i].shape[1:]
            assert torch.allclose(grad[n], p.grad1[i].mean(dim=0))
```
The pull request also addresses Issue #3.
This, of course, requires an extra index on the `grad1` attribute, which now becomes a list with one entry per backward pass. Let me know whether you are happy with this breaking change to your API.
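To make the layout change concrete, here is a hedged sketch of the new indexing (names `n1`, `n2` are taken from the test above; plain nested lists stand in for the `[batch, *param_shape]` tensors that `compute_grad1` produces, so the sketch is standalone):

```python
# Sketch only, not part of the diff: illustrates that grad1 is now a
# list with one per-example gradient "tensor" per forward/backward pass.
n1, n2 = 4, 10
param_shape = (3, 5)  # shape of some hypothetical parameter p

def fake_per_example_grads(batch):
    # Stands in for the [batch, *param_shape] tensor of per-example gradients
    return [[[0.0] * param_shape[1] for _ in range(param_shape[0])]
            for _ in range(batch)]

grad1 = [fake_per_example_grads(n1), fake_per_example_grads(n2)]

assert len(grad1) == 2      # one entry per backward pass
assert len(grad1[0]) == n1  # first pass: leading batch dimension is n1
assert len(grad1[1]) == n2  # second pass: leading batch dimension is n2
```

Previously `p.grad1` was a single tensor; code written against the old API needs a `[0]` (or the appropriate pass index) to keep working.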
In the future, I'd also like to see this module released on PyPI, for maintenance reasons.