HEmile / storchastic

Stochastic Automatic Differentiation library for PyTorch.
GNU General Public License v3.0
180 stars, 5 forks

Proper wrapping of PyTorch using __torch_function__ #60

Closed. HEmile closed this 4 years ago.

HEmile commented 4 years ago

Now that PyTorch 1.5 has introduced the __torch_function__ magic method, which PyTorch calls whenever a torch function receives a Tensor-like argument, we can wrap PyTorch much more cleanly.
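As an illustration of the mechanism (not storch's actual storch.Tensor implementation), here is a minimal sketch of a Tensor-like wrapper. WrappedTensor is a hypothetical name. It defines __torch_function__ as a classmethod, the form recommended by current PyTorch releases. The wrapper unwraps its arguments, calls the original torch function and re-wraps any tensor result, so the wrapper propagates through torch calls without patching torch itself.

```python
import torch


class WrappedTensor:
    """Hypothetical Tensor-like that wraps a torch.Tensor (illustration only)."""

    def __init__(self, data):
        self._data = torch.as_tensor(data)

    def __repr__(self):
        return f"WrappedTensor({self._data!r})"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # PyTorch calls this whenever a torch function receives a WrappedTensor.
        if kwargs is None:
            kwargs = {}

        def unwrap(x):
            # Replace wrappers by their underlying tensor, also inside lists/tuples
            # (e.g. the tensor list passed to torch.cat).
            if isinstance(x, cls):
                return x._data
            if isinstance(x, list):
                return [unwrap(y) for y in x]
            if isinstance(x, tuple):
                return tuple(unwrap(y) for y in x)
            return x

        result = func(*unwrap(args), **{k: unwrap(v) for k, v in kwargs.items()})
        # Re-wrap tensor results so the wrapper flows through the computation.
        if isinstance(result, torch.Tensor):
            return cls(result)
        return result


w = WrappedTensor([1.0, 2.0])
print(torch.add(w, 1))      # WrappedTensor(tensor([2., 3.]))
print(torch.cat([w, w]))    # a WrappedTensor holding 4 elements
```

Only torch.* functions are intercepted this way; Python operators such as w + 1 would additionally need __add__ and related methods defined on the class.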

This branch replaces the earlier monkey-patching design with the __torch_function__ design. Quick tests show that it works very well. One caveat concerns creating Distributions from storch Tensors: the broadcast_all function does not support Tensor-likes, so either that function or torch.is_tensor still has to be monkey patched.
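The torch.is_tensor route of that workaround could be scoped with a context manager, so the patch is active only while a Distribution is being built. This is a sketch: patched_is_tensor and TensorLike are hypothetical names, and it assumes, as the caveat implies, that the code being worked around (broadcast_all in the PyTorch version of the time) calls torch.is_tensor through the torch module at call time. Newer PyTorch releases may check for Tensor-likes differently, so treat this as an illustration of the patching pattern rather than storch's actual fix.

```python
import contextlib

import torch


@contextlib.contextmanager
def patched_is_tensor(tensor_like_cls):
    """Temporarily make torch.is_tensor accept instances of tensor_like_cls.

    Hypothetical helper, not part of storch or PyTorch. Only code that looks up
    torch.is_tensor at call time sees the patch; code that imported the
    function directly keeps the original.
    """
    original = torch.is_tensor

    def is_tensor(obj):
        # Treat the wrapper class as a tensor, defer to PyTorch otherwise.
        return isinstance(obj, tensor_like_cls) or original(obj)

    torch.is_tensor = is_tensor
    try:
        yield
    finally:
        # Always restore the original, even if the body raises.
        torch.is_tensor = original


class TensorLike:
    """Hypothetical stand-in for a storch Tensor-like."""


with patched_is_tensor(TensorLike):
    assert torch.is_tensor(TensorLike())
    assert torch.is_tensor(torch.zeros(1))
assert not torch.is_tensor(TensorLike())
```

Scoping the patch keeps the global change from leaking into unrelated code, at the cost of wrapping every place where such Distributions are constructed.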

Closes #59.