Closed: aleaverfay closed this 1 year ago.
Patch coverage: 100.00% and no project coverage change. Comparison is base (1ebf577) 95.18% compared to head (931e73b) 95.18%.
It occurs to me that a non-zero stride is a necessary but not sufficient condition for a well-allocated tensor cast. If the strides do not line up with the size of the object we're casting to, the cast silently reads the wrong memory. For example, suppose we have an [n-pose x max-n-atoms x 3 x 20] coords tensor and we slice it as coords[:, :, :, 7] to get the eighth set of coordinates; the surviving last dimension now has a stride of 20 rather than 1. When we cast that slice to an Eigen::Matrix<Real, 3, 1>, we will actually read the X coordinates of copies 7, 8, and 9 instead of the X, Y, and Z coordinates of the 7th copy. I think we should actually be asserting that stride[i] equals the product of size[j] over j > i (row-major contiguity), i.e. that there is a contiguous block of memory that we are casting to.
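A minimal pure-Python sketch of that stronger check (the sizes and the `n_pose`/`max_n_atoms` values here are made up for illustration; this is not the actual TView code):

```python
from math import prod

def contiguous_strides(sizes):
    """Row-major (C-order) strides of a contiguous tensor: stride[i] = prod(sizes[i+1:])."""
    return [prod(sizes[i + 1:]) for i in range(len(sizes))]

def is_contiguous(sizes, strides):
    """The proposed assertion: strides must match the contiguous layout exactly."""
    return strides == contiguous_strides(sizes)

# Hypothetical [n_pose x max_n_atoms x 3 x 20] coords tensor with n_pose=2, max_n_atoms=5.
sizes = [2, 5, 3, 20]
strides = contiguous_strides(sizes)   # [300, 60, 20, 1]

# coords[:, :, :, 7] drops the last dimension; the surviving strides are unchanged.
sliced_sizes = sizes[:3]              # [2, 5, 3]
sliced_strides = strides[:3]          # [300, 60, 20]

# Every stride is non-zero, yet the slice is NOT a contiguous block:
# the last dimension has stride 20, not 1, so a contiguity-assuming cast misreads it.
assert all(s != 0 for s in sliced_strides)
assert not is_contiguous(sliced_sizes, sliced_strides)
```

A non-zero-stride check alone would accept this slice; the contiguity check rejects it.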
In its backward pass, torch.sum(t) produces a gradient tensor, t2, which masquerades as a tensor with the same shape as t; its strides give it away as not being a "real" tensor, though: they are all 0. (A stride of 0 means "ignore the index in this dimension," and a tensor whose strides are all zero means "ignore every index used to look up data; there is only one value.") But if t is a tensor of coordinates and we use TCAST to coerce t2 into a TView<Eigen::Matrix<float, 3, 1>, N, D>, the fact that dE_dx, dE_dy, and dE_dz all live at the same address will be LOST, and the dE_dy and dE_dz values will be read off the end of the array! Oh no!
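In pure Python, the hazard looks like this (a toy model of strided indexing over a one-element backing store, standing in for the expanded gradient; not the real torch or TView machinery):

```python
# A zero-stride "tensor": shape claims 3 elements, but storage holds only one,
# as with the gradient that torch.sum's backward pass expands to t's shape.
storage = [1.0]
shape, strides = (3,), (0,)

# Correct stride-aware indexing: every index maps to offset i * 0 == 0.
stride_aware = [storage[i * strides[0]] for i in range(shape[0])]
assert stride_aware == [1.0, 1.0, 1.0]

# A cast that assumes contiguity computes offsets 0, 1, 2 and reads past the end.
try:
    contiguous_read = [storage[i * 1] for i in range(shape[0])]
except IndexError:
    contiguous_read = "read off the end of the array"
assert contiguous_read == "read off the end of the array"
```

In C++ the out-of-bounds read would not raise an exception; it would silently return garbage, which is why the cast needs to reject zero strides up front.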
To fix this, we can simply look at every dimension consumed by the enable_tensor_view<SomeClass> class and assert that each consumed dimension has a non-zero stride. Any stride besides zero is fine.
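That check can be sketched as follows (the function name and the `n_consumed` parameter are hypothetical; `n_consumed` stands for the trailing dimension count that an enable_tensor_view<SomeClass> specialization would fold into the viewed type):

```python
def check_consumed_strides(strides, n_consumed):
    """Assert that every consumed (trailing) dimension has a non-zero stride.

    A zero stride in a consumed dimension means the underlying storage does not
    actually contain a distinct element per index, so a contiguity-assuming
    cast would read past the end of the allocation.
    """
    consumed = strides[len(strides) - n_consumed:]
    if any(s == 0 for s in consumed):
        raise ValueError("consumed dimension has zero stride; refusing to cast")

# A real coordinate slice: the consumed dim has stride 20 (non-zero), so it passes.
check_consumed_strides([300, 60, 20], n_consumed=1)

# The all-zero-stride gradient from torch.sum's backward pass is rejected.
try:
    check_consumed_strides([0, 0, 0], n_consumed=1)
    rejected = False
except ValueError:
    rejected = True
assert rejected
```

Note this only guards against the zero-stride case; combined with the earlier contiguity assertion, it would also catch the sliced-tensor case.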