HIPS / autograd

Efficiently computes derivatives of NumPy code.
MIT License

tensordot and axes #466

Closed stefansommer closed 5 years ago

stefansommer commented 5 years ago

As far as I can see, the axes specification of tensordot is ignored when computing derivatives. In this example the Jacobian has the wrong sign (the derivative is taken of -x instead of x):

(lambda x: np.tensordot(np.array([1,0]),np.array([[0,-x],[x,0]]),(0,1)))(1.)
>>> array([0., 1.])

jacobian(lambda x: np.tensordot(np.array([1,0]),np.array([[0,-x],[x,0]]),(0,1)))(1.)
>>> array([ 0., -1.])

>>> check_grads(lambda x: np.tensordot(np.array([1,0]),np.array([[0,-x],[x,0]]),(0,1)))(1.)
Traceback (most recent call last):
...
AssertionError: Derivative (VJP) check of jvp_<lambda> failed with arg
(1.0, array(-0.43170061)):
analytic: -0.1353787902305393
numeric:  0.13537879021567292
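For reference, the correct sign can be confirmed with a plain-NumPy central-difference check (no autograd involved; the step size below is illustrative):

```python
import numpy as np

def f(x):
    # The map from the report: contract [1, 0] against [[0, -x], [x, 0]]
    # over (axis 0 of the vector, axis 1 of the matrix); the result is [0, x].
    return np.tensordot(np.array([1, 0]), np.array([[0.0, -x], [x, 0.0]]), (0, 1))

eps = 1e-6
fd_jac = (f(1.0 + eps) - f(1.0 - eps)) / (2 * eps)  # central difference
print(fd_jac)  # approximately [0., 1.]: the Jacobian should be [0, 1], not [0, -1]
```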
j-towns commented 5 years ago

Thanks a lot for raising this. Having had a look at the tensordot adjoints, they don't look like they're ignoring the axes, so there must be a bug in the way the axes are handled. I'm about to be on holiday for a week but will try to take a look when I get back.

In the meantime, we welcome pull requests if you fancy having a go at fixing it. The tensordot tests (which are clearly missing the above case) are here.

stefansommer commented 5 years ago

You're welcome. How does this look? (The second clause in the if has been removed.)

@primitive
def tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim):
    # The adjoint of the operator
    # A |--> np.tensordot(A, B, axes)
    if B_ndim == 0:
        return G * B

    G_axes = onp.arange(onp.ndim(G))
    if type(axes) is int:
        axes = max(axes, 0)
        B_axes = onp.arange(B_ndim)
        return onp.tensordot(G, B, [G_axes[A_ndim-axes:], B_axes[axes:]])
    else:
        axes0 = [axes[0]] if type(axes[0]) is int else axes[0]
        axes1 = [axes[1]] if type(axes[1]) is int else axes[1]
        axes = [axes0, axes1]
        A_axes = onp.arange(A_ndim)
        B_axes = onp.arange(B_ndim)
        summed_axes = [onp.asarray(axes[0]) % A_ndim,
                       onp.asarray(axes[1]) % B_ndim]
        other_axes  = [onp.delete(A_axes, summed_axes[0]),
                       onp.delete(B_axes, summed_axes[1])]
        out = onp.tensordot(G, B, [G_axes[len(other_axes[0]):], other_axes[1]])
        perm = onp.argsort(onp.concatenate(
            (other_axes[0], summed_axes[0][onp.argsort(summed_axes[1])])))
        return onp.transpose(out, perm)

@primitive
def tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim):
    # The adjoint of the operator
    # B |--> np.tensordot(A, B, axes)
    if A_ndim == 0:
        return G * A

    G_axes = onp.arange(onp.ndim(G))
    if type(axes) is int:
        axes = max(axes, 0)
        A_axes = onp.arange(A_ndim)
        return onp.tensordot(A, G, [A_axes[:A_ndim-axes], G_axes[:A_ndim-axes]])
    else:
        axes0 = [axes[0]] if type(axes[0]) is int else axes[0]
        axes1 = [axes[1]] if type(axes[1]) is int else axes[1]
        axes = [axes0, axes1]
        A_axes = onp.arange(A_ndim)
        B_axes = onp.arange(B_ndim)
        summed_axes = [onp.asarray(axes[0]) % A_ndim,
                       onp.asarray(axes[1]) % B_ndim]
        other_axes  = [onp.delete(A_axes, summed_axes[0]),
                       onp.delete(B_axes, summed_axes[1])]
        out = onp.tensordot(A, G, [other_axes[0], G_axes[:len(other_axes[0])]])
        perm = onp.argsort(onp.concatenate(
            (summed_axes[1][onp.argsort(summed_axes[0])], other_axes[1])))
        return onp.transpose(out, perm)

def test_tensordot_8(): combo_check(np.tensordot, [0, 1], order=3)([R(2)], [R(2,2)], axes=[[0, 1]])
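The defining property of these adjoints is the inner-product identity <G, tensordot(A, B, axes)> = <tensordot_adjoint_0(B, G, axes), A>. A self-contained sketch checking that identity on a concrete case (plain NumPy, with the @primitive decorator dropped and only the tuple-axes branch of the proposed fix copied in; shapes and axes are illustrative):

```python
import numpy as onp

def adjoint_0(B, G, axes, A_ndim, B_ndim):
    # Tuple-axes branch of the proposed tensordot_adjoint_0, standalone.
    G_axes = onp.arange(onp.ndim(G))
    axes0 = [axes[0]] if type(axes[0]) is int else axes[0]
    axes1 = [axes[1]] if type(axes[1]) is int else axes[1]
    A_axes = onp.arange(A_ndim)
    B_axes = onp.arange(B_ndim)
    summed_axes = [onp.asarray(axes0) % A_ndim, onp.asarray(axes1) % B_ndim]
    other_axes = [onp.delete(A_axes, summed_axes[0]),
                  onp.delete(B_axes, summed_axes[1])]
    out = onp.tensordot(G, B, [G_axes[len(other_axes[0]):], other_axes[1]])
    perm = onp.argsort(onp.concatenate(
        (other_axes[0], summed_axes[0][onp.argsort(summed_axes[1])])))
    return onp.transpose(out, perm)

rng = onp.random.RandomState(0)
A, B = rng.randn(2, 3), rng.randn(4, 3)   # contract A's axis 1 with B's axis 1
axes = ([1], [1])
out = onp.tensordot(A, B, axes)           # shape (2, 4)
G = rng.randn(*out.shape)                 # arbitrary cotangent
lhs = onp.sum(G * out)
rhs = onp.sum(adjoint_0(B, G, axes, A.ndim, B.ndim) * A)
print(onp.allclose(lhs, rhs))  # prints True if the adjoint is correct
```

The check for tensordot_adjoint_1 is symmetric, with A held fixed and the perturbation applied to B.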
j-towns commented 5 years ago

Great job! Would you like to submit a pull request, or shall I just merge these changes (and credit you in the commit message)?

stefansommer commented 5 years ago

Thanks, I've submitted a pull request.