
Support differentiability through clone and update #65913

Open tchaton opened 3 years ago

tchaton commented 3 years ago

🚀 Feature

Dear people from PyTorch,

I was reading through the Learn2Learn codebase for meta-learning and came across an interesting hack they had to implement.

https://github.com/learnables/learn2learn/blob/master/learn2learn/utils/__init__.py#L53

    import torch.nn as nn
    from learn2learn.utils import clone_module

    net = nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 2))
    clone = clone_module(net)
    error = loss(clone(X), y)  # loss, X, y defined elsewhere
    error.backward()  # Gradients are back-propagated all the way to net.
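
Conceptually, clone_module duplicates the module's structure and then re-attaches every parameter through a differentiable .clone(). A minimal sketch of the idea (the name and details here are simplified, not the actual learn2learn implementation):

    import copy
    import torch.nn as nn

    def clone_module_sketch(module: nn.Module) -> nn.Module:
        # Duplicate the structure, then swap each (detached) copied
        # parameter for a .clone() of the original. Since .clone() is a
        # differentiable op, gradients taken through the clone flow back
        # to the original module's parameters.
        clone = copy.deepcopy(module)
        for name, param in module.named_parameters():
            submodule = clone
            *path, leaf = name.split(".")
            for part in path:
                submodule = getattr(submodule, part)
            submodule._parameters[leaf] = param.clone()
        return clone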

The second trick, implemented here: https://github.com/learnables/learn2learn/blob/master/learn2learn/utils/__init__.py#L230, performs the parameter update while preserving differentiability.
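
The core of that update trick is small: take gradients with create_graph=True so the gradients themselves stay in the autograd graph, then form the new parameters as an ordinary differentiable expression of the old ones. A sketch with a hypothetical helper name:

    import torch

    def differentiable_sgd_step(loss, module, lr=0.1):
        # Hypothetical helper illustrating the trick, not the learn2learn API.
        params = list(module.parameters())
        # create_graph=True keeps the gradients in the graph, so the
        # update below can itself be differentiated (as in MAML).
        grads = torch.autograd.grad(loss, params, create_graph=True)
        # p - lr * g is a regular autograd op: a later backward() can
        # propagate through the update back to the original parameters.
        return [p - lr * g for p, g in zip(params, grads)]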

I believe it would be interesting for PyTorch to support this natively.

A possible API for cloning:

    model.clone(keep_differentiability=False/True)

and a possible API for a functional optimizer step:

    F.sgd(params, grads, keep_differentiability=False/True)
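
For illustration, here is how the two proposed calls might compose in a MAML-style inner loop (hypothetical, since neither API exists yet; model, F, loss_fn, and the data tensors are placeholders):

    # Hypothetical usage of the proposed APIs:
    clone = model.clone(keep_differentiability=True)
    inner_loss = loss_fn(clone(x_support), y_support)
    grads = torch.autograd.grad(inner_loss, list(clone.parameters()),
                                create_graph=True)
    F.sgd(list(clone.parameters()), grads, keep_differentiability=True)
    outer_loss = loss_fn(clone(x_query), y_query)
    outer_loss.backward()  # gradients reach model's original parameters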


cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @Lezcano @Varal7 @mruberry @jbschlosser @walterddr

tchaton commented 3 years ago

@seba-1511

albanD commented 3 years ago

Hi,

Thanks for sharing these details. We are indeed looking into providing these functionalities (via different APIs). You can see https://github.com/pytorch/pytorch/issues/39279 and https://github.com/pytorch/pytorch/issues/49171 for example.
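
For reference, the stateless/functional direction those issues point toward treats parameters as explicit inputs, so a differentiable update falls out naturally. A sketch using torch.func.functional_call (available in later PyTorch releases; not the API proposed in this issue):

    import torch
    import torch.nn as nn
    from torch.func import functional_call

    net = nn.Linear(4, 2)
    x, y = torch.randn(8, 4), torch.randn(8, 2)

    # Parameters are passed explicitly instead of living as module state.
    params = {k: v.clone() for k, v in net.named_parameters()}
    out = functional_call(net, params, (x,))
    loss = nn.functional.mse_loss(out, y)
    grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
    # The update is a plain differentiable expression over tensors.
    updated = {k: p - 0.1 * g for (k, p), g in zip(params.items(), grads)}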

Would these two provide functionality similar to what you're looking for?