HEmile / storchastic

Stochastic Automatic Differentiation library for PyTorch.
GNU General Public License v3.0

Cleaner overriding of Tensors #59

Closed HEmile closed 4 years ago

HEmile commented 4 years ago

Currently, storch.Tensor wraps PyTorch Tensors using monkey patching. This can cause issues when other patches exist, create compatibility problems, and makes it difficult to update Storchastic to new PyTorch versions.

It looks like PyTorch has worked on this problem.

Documentation

https://pytorch.org/docs/master/notes/extending.html (see: Extending torch)

The __torch_function__ method is called whenever a torch function is called (also for torch.Tensor methods on the overridden tensor?). It receives the function, the types of the arguments, and the args and kwargs passed to the function. This should be enough information to apply the wrapper!

There is a section on extending torch with a Tensor wrapper type. Looking at that example, this should be enough :)
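A minimal sketch of the wrapper pattern from extending.rst, assuming a recent PyTorch where __torch_function__ is a classmethod. The class name LoggingTensor is illustrative, not Storchastic's actual implementation:

```python
import torch


class LoggingTensor:
    """Minimal wrapper sketch: delegates torch functions to a wrapped tensor."""

    def __init__(self, tensor):
        self._t = tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Unwrap any LoggingTensor arguments before delegating to torch.
        unwrapped = tuple(a._t if isinstance(a, LoggingTensor) else a
                          for a in args)
        result = func(*unwrapped, **kwargs)
        # Re-wrap tensor results so the wrapper propagates through ops.
        if isinstance(result, torch.Tensor):
            return LoggingTensor(result)
        return result


w = LoggingTensor(torch.tensor([1.0, 2.0]))
r = torch.add(w, torch.tensor([3.0, 4.0]))  # dispatches to __torch_function__
```

Because the result is re-wrapped, the wrapper survives chains of torch calls, which is the behaviour Storchastic needs in place of monkey patching. A real implementation would also need to handle nested containers in args and the case where func returns tuples of tensors.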

This is not yet in PyTorch 1.4. I should probably wait until it's merged into stable (which should be PyTorch 1.5, as it's already in master), and then re-implement the monkey patching using this method.

https://github.com/pytorch/pytorch/issues/22402

https://github.com/pytorch/pytorch/issues/24015

Contains a dispatch mechanism that lets Tensor-like objects handle torch.somefunction calls the way they see fit: i.e., torch functions with a tensor input parameter become overridable, without overhead when the input is a plain Tensor (nice!).
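The "no overhead for plain Tensors" point can be checked with torch.overrides.has_torch_function, which is part of the public overrides API in current PyTorch (it postdates the 1.5 timeframe discussed here); plain tensors skip the dispatch path entirely, while any type defining __torch_function__ triggers it:

```python
import torch
from torch.overrides import has_torch_function


class MyWrapper:
    """Illustrative wrapper type; the name is an assumption, not from the source."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented


x = torch.ones(2)
# Plain tensors take the fast path: no __torch_function__ dispatch needed.
plain_dispatches = has_torch_function((x,))
# A wrapper type with __torch_function__ triggers the dispatch mechanism.
wrapper_dispatches = has_torch_function((MyWrapper(),))
```

This is why the mechanism is attractive for Storchastic: ordinary PyTorch code pays nothing, and only wrapped tensors take the override path.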

https://github.com/pytorch/pytorch/commit/d12786b24fc7df0526c9fcd69efa776c53c34e3f

Writes how to do this (see extending.rst)

More sources: