alvorithm closed this 4 years ago
Re. `AffineTransform` replacing `AffineScalarTransform`: the test fails on the `logabsdet` of the inverse in `standard_test.py::AffineScalarTransformTest`. I could have a look tomorrow, if you don't immediately see the reason.
Ok, the commits above implement all the changes we've discussed so far, unifying `AffineScalarTransform` and `AffineTransform` under `PointwiseAffineTransform`.

Note that I had to special-case a scalar `scale` to match the tests' expectations to within `eps=1E-6`. Using the `.sum()` idiom leads to disagreements of ~1E-5, and is perhaps slower. But now the special casing is confined to a small function, where it can be suitably documented. This function could also be memoized if the transform is applied multiple times (or even in forward and inverse modes) with the same `batch_shape`.
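A minimal sketch of the scalar special case and of the memoization idea, assuming the log-scale is held as a tensor `log_scale` and that `batch_shape` excludes the leading batch dimension. The names `batch_logabsdet` and `make_cached_logabsdet` are hypothetical and not the code in the commits.

```python
from functools import lru_cache

import torch


def batch_logabsdet(log_scale: torch.Tensor, batch_shape) -> torch.Tensor:
    """log|det J| of x -> x * scale + shift for one sample of shape batch_shape.

    Hypothetical helper illustrating the scalar special case.
    """
    batch_shape = torch.Size(batch_shape)
    if log_scale.numel() > 1:
        # Per-element scale: broadcast to the sample shape, then sum.
        return log_scale.expand(batch_shape).sum()
    # Scalar scale: n * log(scale) instead of summing n identical terms,
    # which avoids accumulating float32 rounding error.
    return log_scale * batch_shape.numel()


def make_cached_logabsdet(log_scale: torch.Tensor):
    """Memoize per batch shape (hypothetical); assumes log_scale stays fixed."""

    @lru_cache(maxsize=None)
    def cached(batch_shape: tuple) -> torch.Tensor:
        return batch_logabsdet(log_scale, batch_shape)

    return cached


log_scale = torch.log(torch.tensor(2.0))
logabsdet = make_cached_logabsdet(log_scale)
fwd = logabsdet((3, 4))   # forward: n * log(scale) with n = 12
inv = -logabsdet((3, 4))  # inverse: reuses the cached value, negated
```

Caching per `batch_shape` only stays valid while the scale is fixed (e.g. a registered buffer); with a trainable scale the cache would go stale after every optimizer step.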
Thanks, Artur, and please keep the nitpicks coming, I love them!
In this installment:

- `.size()` instead of `.shape`, since a) we are requiring `inputs` to be tensors and b) in preparation for named tensors, which will be supported by `.size()` but, I reckon, not by `.shape`. It's a pity that `torch` did not follow numpy terminology here, because as a consequence this reads slightly confusingly: `batch_size, *batch_shape = inputs.size()`. The torch-consistent naming would be `batch_len, *batch_size = inputs.size()`.
- The `prod` dance is now wrapped in a `numel` function in `torchutils` that works with Tensors or Tensor Sizes.
- `README.md` as entry point to understand the package.
- `torchutils`:
  - I reimplemented `notinfnotnan` using `torch.isfinite`. Since the idiom is so short and self-explanatory, and I could find no uses of it in the code, I would actually suggest removing the function.
  - Dropped `ensure_tensor`, since `torch.as_tensor` does what we want.
  - `numel`. Caveat: `<my_tensor>.numel()` and `<my_size>.numel()` each do the right thing, but then you have to make sure that you actually have a size.
  - `notinfnotnan`, as `Tensor` (I learnt) cannot be subscripted. We could solve this by defining our own types, but I'm postponing this since `notinfnotnan` no longer seems necessary.
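A quick sketch of the `torch` behaviour the points above rely on. The `numel` function here is a hypothetical stand-in for the `torchutils` helper, not its actual code.

```python
import torch


def numel(size_or_tensor) -> int:
    """Hypothetical stand-in for the torchutils numel helper.

    Tensors and torch.Size both have .numel(); plain tuples/lists do not,
    so anything that is not a Tensor is converted to torch.Size first.
    """
    if isinstance(size_or_tensor, torch.Tensor):
        return size_or_tensor.numel()
    return torch.Size(size_or_tensor).numel()


inputs = torch.randn(8, 3, 4)

# .size() and .shape give the same torch.Size.
assert inputs.size() == inputs.shape

# First dim is the batch; the rest (a plain list) is the per-sample shape.
batch_size, *batch_shape = inputs.size()

# A hand-rolled "not inf and not nan" check matches torch.isfinite.
x = torch.tensor([1.0, float("inf"), float("nan"), -2.0])
hand_rolled = ~torch.isinf(x) & ~torch.isnan(x)
assert torch.equal(hand_rolled, torch.isfinite(x))
```

Note that `batch_shape` comes out of the star-unpacking as a `list`, so calling `batch_shape.numel()` directly would raise `AttributeError`; that is the "make sure you have a size" caveat, and wrapping it in `torch.Size` first sidesteps it.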
This PR adds a little helper function to prevent tensor construction warnings when passing a `torch.Tensor`-typed argument to `torch.tensor()`.
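A minimal sketch of such a helper, assuming the only goal is to avoid the copy-construct `UserWarning` that `torch.tensor()` emits for tensor arguments; the name `to_tensor` is hypothetical. For tensors it follows the warning's own advice and uses `clone().detach()`.

```python
import warnings

import torch


def to_tensor(value, dtype=None) -> torch.Tensor:
    """Hypothetical helper: build a tensor without the copy-construct warning."""
    if isinstance(value, torch.Tensor):
        # torch.tensor(value) would warn here; clone().detach() is the
        # recommended way to copy an existing tensor.
        out = value.clone().detach()
        return out if dtype is None else out.to(dtype)
    # Non-tensor input (Python scalar, list, ...): torch.tensor is fine.
    return torch.tensor(value, dtype=dtype)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    a = to_tensor(torch.ones(2))
    b = to_tensor([1.0, 2.0])
```

If a copy is not actually needed, `torch.as_tensor` avoids both the warning and the copy: it returns its input unchanged when the dtype and device already match.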