stefansommer closed 5 years ago
Thanks a lot for raising this. Having had a look at the tensordot adjoints, they don't look like they're ignoring the axes, so there must be a bug in the way the axes are handled. I'm about to be on holiday for a week but will try to take a look when I get back.
In the meantime, we welcome pull requests if you fancy having a go at fixing it. The tensordot tests (which are clearly missing the above case) are here.
You are welcome. How does this look (with the second clause in the if removed)?
@primitive
def tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim):
    # The adjoint of the operator
    # A |--> np.tensordot(A, B, axes)
    if B_ndim == 0:
        return G * B
    G_axes = onp.arange(onp.ndim(G))
    if type(axes) is int:
        axes = max(axes, 0)
        B_axes = onp.arange(B_ndim)
        return onp.tensordot(G, B, [G_axes[A_ndim-axes:], B_axes[axes:]])
    else:
        axes0 = [axes[0]] if type(axes[0]) is int else axes[0]
        axes1 = [axes[1]] if type(axes[1]) is int else axes[1]
        axes = [axes0, axes1]
        A_axes = onp.arange(A_ndim)
        B_axes = onp.arange(B_ndim)
        summed_axes = [onp.asarray(axes[0]) % A_ndim,
                       onp.asarray(axes[1]) % B_ndim]
        other_axes = [onp.delete(A_axes, summed_axes[0]),
                      onp.delete(B_axes, summed_axes[1])]
        out = onp.tensordot(G, B, [G_axes[len(other_axes[0]):], other_axes[1]])
        perm = onp.argsort(onp.concatenate(
            (other_axes[0], summed_axes[0][onp.argsort(summed_axes[1])])))
        return onp.transpose(out, perm)
@primitive
def tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim):
    # The adjoint of the operator
    # B |--> np.tensordot(A, B, axes)
    if A_ndim == 0:
        return G * A
    G_axes = onp.arange(onp.ndim(G))
    if type(axes) is int:
        axes = max(axes, 0)
        A_axes = onp.arange(A_ndim)
        return onp.tensordot(A, G, [A_axes[:A_ndim-axes], G_axes[:A_ndim-axes]])
    else:
        axes0 = [axes[0]] if type(axes[0]) is int else axes[0]
        axes1 = [axes[1]] if type(axes[1]) is int else axes[1]
        axes = [axes0, axes1]
        A_axes = onp.arange(A_ndim)
        B_axes = onp.arange(B_ndim)
        summed_axes = [onp.asarray(axes[0]) % A_ndim,
                       onp.asarray(axes[1]) % B_ndim]
        other_axes = [onp.delete(A_axes, summed_axes[0]),
                      onp.delete(B_axes, summed_axes[1])]
        out = onp.tensordot(A, G, [other_axes[0], G_axes[:len(other_axes[0])]])
        perm = onp.argsort(onp.concatenate(
            (summed_axes[1][onp.argsort(summed_axes[0])], other_axes[1])))
        return onp.transpose(out, perm)
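For anyone following the axis bookkeeping, here is a rough standalone sketch of the list-of-axes branch of the adjoint with respect to A, using plain NumPy only (no autograd). The shapes and axes are made up for illustration, and the einsum reference is my own reading of what the adjoint should be for them. The key step is that argsort of a permutation yields its inverse, which restores A's original axis order.

```python
import numpy as onp  # plain NumPy, same alias as in the snippet above

rng = onp.random.default_rng(0)

# Hypothetical shapes: contract A's axes (2, 0) with B's axes (0, 1).
A = rng.standard_normal((2, 3, 4))
B = rng.standard_normal((4, 2, 5))
axes = [[2, 0], [0, 1]]

# G plays the role of the incoming gradient; it has the shape of the output.
G = rng.standard_normal(onp.tensordot(A, B, axes).shape)  # shape (3, 5)

summed_axes = [onp.asarray(axes[0]) % A.ndim,
               onp.asarray(axes[1]) % B.ndim]
other_axes = [onp.delete(onp.arange(A.ndim), summed_axes[0]),
              onp.delete(onp.arange(B.ndim), summed_axes[1])]

# Contract G's trailing axes (those that came from B) with B's free axes.
G_axes = onp.arange(G.ndim)
out = onp.tensordot(G, B, [G_axes[len(other_axes[0]):], other_axes[1]])

# `order[k]` is the axis of A that axis k of `out` corresponds to:
# first A's free axes, then A's summed axes in the order of B's summed axes.
order = onp.concatenate(
    (other_axes[0], summed_axes[0][onp.argsort(summed_axes[1])]))

# argsort of a permutation is its inverse, so this puts the axes back
# into A's own order.
perm = onp.argsort(order)
A_bar = onp.transpose(out, perm)

# Reference written out index by index:
# C[j, l] = sum_{i, k} A[i, j, k] * B[k, i, l]
# so A_bar[i, j, k] = sum_l G[j, l] * B[k, i, l].
expected = onp.einsum('jl,kil->ijk', G, B)
print(A_bar.shape == A.shape, onp.allclose(A_bar, expected))
```

Without the final transpose, `out` has shape (3, 4, 2) here rather than A's shape (2, 3, 4), which is why the permutation matters.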
def test_tensordot_8(): combo_check(np.tensordot, [0, 1], order=3)([R(2)], [R(2,2)], axes=[[0, 1]])
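As a sanity check on what that test case should produce, here is a small sketch in plain NumPy (so `np` below is ordinary NumPy, not autograd's wrapper). It compares hand-derived adjoints for a 1-D A and a 2-D B contracted with axes=[0, 1] against central finite differences. The helper `numeric_vjp` is a hypothetical name of mine, not part of autograd or its test suite.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal(2)        # same shape as R(2) in the test
B = rng.standard_normal((2, 2))   # same shape as R(2, 2) in the test
G = rng.standard_normal(2)        # cotangent with the shape of the output


def f(A, B):
    # Contract axis 0 of A with axis 1 of B: C[j] = sum_i A[i] * B[j, i]
    return np.tensordot(A, B, axes=[0, 1])


def numeric_vjp(fun, x, g, eps=1e-6):
    # Central-difference estimate of the gradient of sum(fun(x) * g)
    # with respect to x (hypothetical helper, for illustration only).
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        step = np.zeros_like(x)
        step[idx] = eps
        grad[idx] = np.sum((fun(x + step) - fun(x - step)) * g) / (2 * eps)
    return grad


# Hand-derived adjoints for this particular contraction.
A_bar_expected = B.T @ G          # A_bar[i] = sum_j G[j] * B[j, i]
B_bar_expected = np.outer(G, A)   # B_bar[j, i] = G[j] * A[i]

A_bar_numeric = numeric_vjp(lambda a: f(a, B), A, G)
B_bar_numeric = numeric_vjp(lambda b: f(A, b), B, G)

print(np.allclose(A_bar_numeric, A_bar_expected),
      np.allclose(B_bar_numeric, B_bar_expected))
```

If the axes were ignored, the contraction would pair axis 0 of A with axis 0 of B instead, and the result would match B.T rather than B whenever B is not symmetric.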
Great job! Would you like to submit a pull request, or shall I just merge these changes (and credit you in the commit message)?
Thanks, I've submitted a pull request.
As far as I can see, the axis specification of tensordot is ignored when computing derivatives. In this example, the Jacobian has the wrong sign (it takes the derivative of -x instead of x):