Currently, `storch.Tensor` wraps around PyTorch Tensors using monkey patching. This can cause issues when other patches exist, creates compatibility problems, and makes it difficult to update Storchastic to new PyTorch versions. It looks like PyTorch has worked on this problem.
Documentation
https://pytorch.org/docs/master/notes/extending.html (see: Extending torch)
The `__torch_function__` method is called whenever a `torch` function is called (also a `torch.Tensor()` method on the overridden tensor?). It passes the function, the types of the arguments, and the args and kwargs given to the function. This should be enough information to apply the wrapper!

There is a part on extending `torch` with a Tensor wrapper type. Looking at that example, that should be enough :)
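A sketch of that wrapper-type pattern (the class name and the unwrap/rewrap logic are illustrative, not Storchastic's actual wrapper; the signature follows the current extending notes, where `__torch_function__` is a classmethod receiving `func`, `types`, `args`, and `kwargs`):

```python
import torch

class WrappedTensor:
    """Minimal Tensor wrapper that intercepts torch functions via
    __torch_function__. Illustrative only -- not Storchastic's API."""

    def __init__(self, data):
        self._t = torch.as_tensor(data)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap wrapper arguments, call the real torch function, then
        # re-wrap Tensor results so the wrapper propagates through torch calls.
        unwrapped = [a._t if isinstance(a, cls) else a for a in args]
        result = func(*unwrapped, **kwargs)
        return cls(result) if isinstance(result, torch.Tensor) else result

x = WrappedTensor([1.0, 2.0])
y = torch.add(x, x)  # dispatched to WrappedTensor.__torch_function__
```

Since `func` and the full argument list are handed to the hook, the wrapper can apply its own logic once, centrally, instead of patching every torch function individually.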
This is not yet in PyTorch 1.4. I should probably wait until it's merged into stable (which should be PyTorch 1.5, as it's already in master), and then re-implement the monkey patching using this method.
https://github.com/pytorch/pytorch/issues/22402: "…`Tensor` subclasses when calling `torch` functions on them." This sounds exactly like what we need.

https://github.com/pytorch/pytorch/issues/24015
Contains a dispatch mechanism for `Tensor`-like objects that handle a `torch.somefunction` call the way they see fit: i.e., torch functions with a tensor input parameter become overridable, without overhead when the input IS a `Tensor` (nice!).

https://github.com/pytorch/pytorch/commit/d12786b24fc7df0526c9fcd69efa776c53c34e3f
Describes how to do this (see extending.rst).
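In later PyTorch releases that dispatch check is exposed directly; a small sketch assuming the `torch.overrides` module (which landed after the 1.4/1.5 timeframe discussed above):

```python
import torch
from torch.overrides import has_torch_function

class TensorLike:
    # Any class defining __torch_function__ participates in dispatch.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented

# Plain Tensors take the fast path: no override is detected, so the
# original torch function runs without extra dispatch overhead.
print(has_torch_function((torch.tensor(1.0),)))
# Tensor-like objects are detected and routed through __torch_function__.
print(has_torch_function((TensorLike(),)))
```

This is the check behind the "no overhead when the input IS a `Tensor`" property: the override test is cheap and only diverts execution for non-Tensor, tensor-like arguments.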
More sources: