pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Flattening nn.Parameters while maintaining gradients from neural network forward pass #49723

Open 18jeffreyma opened 3 years ago

18jeffreyma commented 3 years ago

🚀 Feature

Given some neural network a, a.parameters() yields one Tensor per parameter in the network. This request is for an alternate accessor, a.parameters_flattened(), that returns a single Nx1 tensor sharing the same elementwise gradients as the tensors in a.parameters().

Motivation

When defining custom gradient descent methods and optimizers for networks (like here), it's often much easier (and cleaner) to assume the parameters of a network form an Nx1 vector on which we can perform matrix operations. This reduces the overhead and casework that come from having the parameters split up into multiple tensors.

However, if we try to reshape the list of tensors returned by network.parameters(), this creates a new computation branch, which does not share gradients with a normal forward pass of the neural network. Is there some way of returning the parameters of the network as a tensor that shares the same gradients as the network parameters?
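A minimal sketch of the problem described above, assuming a toy `nn.Linear` model (the model, shapes, and variable names are illustrative only). Concatenating the parameters produces a new tensor that the forward pass never touches, so autograd has no gradient for it; the per-parameter gradients can still be gathered into a flat vector after the fact, at the cost of a copy.

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Linear(3, 2)  # toy model: 6 weights + 2 biases

# Flattening creates a NEW tensor derived from the parameters.
flat = torch.cat([p.reshape(-1) for p in net.parameters()])

# The forward pass uses net.weight and net.bias, not `flat`.
loss = net(torch.randn(4, 3)).sum()

# `flat` is not part of the graph that produced `loss`, so its gradient is None.
(grad_wrt_flat,) = torch.autograd.grad(
    loss, flat, allow_unused=True, retain_graph=True
)

# Gradients with respect to the real parameters exist and can be flattened,
# but this is a copy, not a shared view.
grads = torch.autograd.grad(loss, list(net.parameters()))
flat_grad = torch.cat([g.reshape(-1) for g in grads])
```

Here `grad_wrt_flat` is `None`, while `flat_grad` has the same number of elements as `flat`.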

Pitch

Basically, in addition to network.parameters() returning the tensors that represent the weights of the neural net, I was wondering if there exists something like network.parameters_flattened(), which contains the same weights, but flattened, and which also shares gradients with a forward pass of the network.

Alternatives

In defining a custom optimizer that optimizes the weights of multiple players, I've instead used a less efficient representation: I represent each player's parameters as a list of tensors and map a function over the lists whenever I need to, say, add two players' parameters together. An idea of this is here.
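The list-of-tensors workaround might look roughly like the following. This is only an illustration of the idea, not the code referred to above; the helper names `add_params` and `dot_params` and the two "players" are hypothetical.

```python
import torch

def add_params(xs, ys):
    """Elementwise sum of two parameter lists with matching shapes."""
    return [x + y for x, y in zip(xs, ys)]

def dot_params(xs, ys):
    """Inner product of two parameter lists, as if both were flattened."""
    return sum((x * y).sum() for x, y in zip(xs, ys))

# Two hypothetical players with identically shaped parameter lists.
player_a = [torch.ones(2, 3), torch.ones(3)]
player_b = [torch.full((2, 3), 2.0), torch.full((3,), 2.0)]

total = add_params(player_a, player_b)  # list of tensors, every entry 3.0
inner = dot_params(player_a, player_b)  # scalar tensor: 6 * 2 + 3 * 2 = 18
```

Every vector operation needs a helper like this, which is the casework the feature request is trying to avoid.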

Additional context

Specifically for the research community in prototyping, I think this would be an extremely useful feature, but I understand if it might not be possible given how reshaping and concatenation produce new tensors by default.

cc @albanD @mruberry @jbschlosser @vincentqb

ngimel commented 3 years ago

This is a popular request, and there's even an add-on GitHub repo that attempts to do it: https://github.com/PhilJd/contiguous_pytorch_params. However, this is unlikely to be implemented in the core. The "Debugging" section in the linked repo gives some reasons why: basically, many things could go wrong and gradients will end up being disconnected. To elaborate on and add to the reasons listed there, here's an incomplete list of situations when parameters (and gradients) cannot be represented as a flat buffer:

1) parameters are of different types
2) parameters are on different devices
3) alignment of the parameters in the flat list is insufficient, hurting performance
4) the model is modified (moved to a different device, converted to a different dtype)

In addition, if you also want gradients in the flattened buffer:

5) not all the parameters are used in the forward path, and hence they are not getting gradients
6) gradients can be returned as a broadcasted tensor, so physically they occupy less memory than the corresponding parameter
7) the current implementation does not allow putting a gradient in a pre-determined location, so a copy will be required.
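To make a couple of these cases concrete, here is a rough sketch under stated assumptions: a hand-rolled flat buffer (`flat`, a hypothetical name) whose slices back each parameter's `.data`, similar in spirit to what the linked add-on does but not taken from it. It shows the aliasing working, then case 5 (an unused layer gets no gradient) and case 4 (a dtype conversion silently detaches the parameters from the buffer).

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 2), nn.Linear(2, 1))
params = list(net.parameters())

# Hypothetical flat buffer; each parameter's data becomes a view into it.
flat = torch.zeros(sum(p.numel() for p in params))
offset = 0
for p in params:
    n = p.numel()
    flat[offset:offset + n].copy_(p.data.reshape(-1))
    p.data = flat[offset:offset + n].view_as(p)
    offset += n

# While the views are intact, writing to the buffer updates the parameters.
before = params[0].detach().clone()
flat.add_(1.0)
aliased = torch.equal(params[0].detach(), before + 1.0)

# Case 5: only the first layer is used, so the second layer gets no gradient.
loss = net[0](torch.randn(4, 3)).sum()
loss.backward()
unused_has_no_grad = [p.grad is None for p in net[1].parameters()]

# Case 4: converting the dtype allocates new storage for every parameter,
# so the buffer no longer aliases them.
net.double()
weight = next(net.parameters())
before = weight.detach().clone()
flat.add_(1.0)
still_aliased = torch.equal(weight.detach(), before + 1.0)
```

After the conversion, `still_aliased` is `False`: updates to `flat` no longer reach the model, and nothing raises an error to say so.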

In short, attempts to flatten parameters have been made, and every time there were corner cases that made it very difficult to do so robustly. But there's good news: PyTorch now provides APIs that allow you to operate directly on lists of tensors (which can be your parameter lists or gradient lists). See, e.g., how an optimizer is implemented using these _foreach APIs: https://github.com/pytorch/pytorch/blob/master/torch/optim/_multi_tensor/sgd.py. For now, the _foreach APIs support only pointwise operations, so no dot or norm or addr, but we can consider adding those if there are use cases.
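As a hedged illustration of the list-based approach, the sketch below applies a plain SGD step to a whole parameter list with a single `torch._foreach_add_` call. It is a simplified stand-in, not the linked optimizer; the toy model and learning rate are assumptions, and since the `_foreach` functions are underscore-prefixed their behavior may differ between versions.

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Linear(3, 2)  # toy model, for illustration only
loss = net(torch.randn(4, 3)).pow(2).sum()
loss.backward()

lr = 0.1  # assumed learning rate
params = [p for p in net.parameters() if p.grad is not None]
grads = [p.grad for p in params]

# Reference result computed tensor by tensor, for comparison.
expected = [p.detach() - lr * g for p, g in zip(params, grads)]

# One call updates every tensor in the list in place: p <- p + (-lr) * g.
with torch.no_grad():
    torch._foreach_add_(params, grads, alpha=-lr)

matches = all(torch.allclose(p.detach(), e) for p, e in zip(params, expected))
```

The parameters stay as separate tensors, so none of the flat-buffer corner cases apply; the list call only removes the Python-level loop.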

18jeffreyma commented 3 years ago

Thanks for the link to the add-on library and for answering my question so quickly! Is there some documentation page where I can find these _foreach* APIs in more detail?

ngimel commented 3 years ago

Unfortunately there are no docs yet; you can look in native_functions.yaml to see which functions are supported. cc @izdeby
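As a complement to reading native_functions.yaml, one way to see what a given installation exposes is to list the matching names on the `torch` module. This is only a convenience sketch; the exact set of names depends on the installed PyTorch version.

```python
import torch

# Collect every function on the torch module whose name starts with "_foreach".
foreach_names = sorted(n for n in dir(torch) if n.startswith("_foreach"))

# Names ending in "_" are the in-place variants.
in_place = [n for n in foreach_names if n.endswith("_")]
out_of_place = [n for n in foreach_names if not n.endswith("_")]

print(f"{len(foreach_names)} _foreach functions found")
print(out_of_place[:10])
```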