DedalusProject / dedalus

A flexible framework for solving PDEs with modern spectral methods.
http://dedalus-project.org/
GNU General Public License v3.0

symbolic differentiation of scalars wrt tensors #211

Closed liamoconnor9 closed 2 years ago

liamoconnor9 commented 2 years ago

It'd be nice if we could differentiate dot products with respect to their (vector) arguments. This would be particularly useful for the adjoint-looping optimization work I'm doing.

Here's an example of the functionality I'm looking for:

    import numpy as np
    import dedalus.public as d3

    coords = d3.CartesianCoordinates('x', 'z')
    dist = d3.Distributor(coords, dtype=np.float64)
    xbasis = d3.RealFourier(coords['x'], size=64, bounds=(0, 2*np.pi), dealias=3/2)
    zbasis = d3.Chebyshev(coords['z'], size=64, bounds=(-1, 1), dealias=3/2)
    x, z = dist.local_grids(xbasis, zbasis)

    u = dist.VectorField(coords, name='u', bases=(xbasis, zbasis))
    u['g'][0] = np.sin(x)*z
    u['g'][1] = np.cos(x)*z**2

    usqrd = d3.dot(u, u)

    # Desired behavior: d(u.u)/du should equal 2u.
    # This does not currently pass, because u.sym_diff(u) returns 1.
    assert np.allclose(usqrd.sym_diff(u).evaluate()['g'], 2*u['g'])

Currently, for a TensorField u, u.sym_diff(u) returns 1. If it instead returned an identity tensor of the proper rank, the test case above would pass.
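The pointwise math behind this request can be checked without Dedalus. The sketch below is plain NumPy, not Dedalus code. It makes a few assumptions for illustration: the vector field is stored as an array with its components along the first axis (mirroring u['g']), and uniform grids stand in for the Fourier/Chebyshev grids. It forms d(u_i u_i)/du_j = 2 u_i δ_ij = 2 u_j by contracting u with an identity tensor, then compares the result against central finite differences.

```python
import numpy as np

# Uniform stand-in grids (not Dedalus's Fourier/Chebyshev grids).
x = np.linspace(0, 2*np.pi, 64, endpoint=False)[:, None]
z = np.linspace(-1, 1, 64)[None, :]

# Vector field with components along axis 0, like u['g'] in the issue.
u = np.stack([np.sin(x)*z, np.cos(x)*z**2])

def usqrd(u):
    """Pointwise dot product u.u, summing over the component axis."""
    return np.einsum('i...,i...->...', u, u)

# Analytic derivative: d(u_i u_i)/du_j = 2 u_i delta_ij = 2 u_j.
# The rank-2 identity delta_ij plays the role that u.sym_diff(u)
# would need to play (instead of the scalar 1) to give 2u.
identity = np.eye(2)
grad = 2 * np.einsum('i...,ij->j...', u, identity)

# Central finite differences, one vector component at a time.
h = 1e-6
fd = np.empty_like(u)
for j in range(2):
    e = np.zeros_like(u)
    e[j] = h
    fd[j] = (usqrd(u + e) - usqrd(u - e)) / (2*h)

print(np.allclose(grad, 2*u), np.allclose(fd, grad))
```

Because u.u is quadratic, the central difference is exact apart from rounding, so both checks should agree to near machine precision.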

kburns commented 2 years ago

The correct way to do this is to compute the Frechet differential, not to use sym_diff.
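For the example above, with F(u) = u·u evaluated pointwise, the differential in a direction δu follows from expanding F(u + ε δu) and differentiating at ε = 0. This worked derivation is added for illustration. It is the directional (Gateaux) derivative, which agrees with the Fréchet differential for a smooth map like this one:

```latex
F(\mathbf{u}) = \mathbf{u}\cdot\mathbf{u}, \qquad
dF(\mathbf{u})[\delta\mathbf{u}]
  = \left.\frac{d}{d\epsilon}\, F(\mathbf{u} + \epsilon\,\delta\mathbf{u})\right|_{\epsilon=0}
  = \left.\frac{d}{d\epsilon}\left(\mathbf{u}\cdot\mathbf{u}
      + 2\epsilon\,\mathbf{u}\cdot\delta\mathbf{u}
      + \epsilon^{2}\,\delta\mathbf{u}\cdot\delta\mathbf{u}\right)\right|_{\epsilon=0}
  = 2\,\mathbf{u}\cdot\delta\mathbf{u}
```

The differential is linear in δu. Writing it as 2u·δu identifies the gradient with respect to u as 2u, which is what the assertion in the original example checks.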

kburns commented 2 years ago

Whoops, didn't mean to close.

liamoconnor9 commented 2 years ago

> The correct way to do this is to compute the Frechet differential, not to use sym_diff.

OK, my mistake. I think we could reuse a fair amount of existing code to implement this in Dedalus. Would that be a useful feature, or should I add it to my own project outside of Dedalus?

kburns commented 2 years ago

This already works; it's how the Jacobian is computed for the NLBVP: https://github.com/DedalusProject/dedalus/blob/579b32b774eac77db9abbd1db9f546f171ece743/dedalus/core/problems.py#L397
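The idea can be illustrated outside Dedalus. This sketch is not how Dedalus implements it, and every name below is hypothetical. The differential dF(u)[δu] is the coefficient of ε after substituting u → u + ε δu, with ε² treated as zero. A small dual-number class (Dual) makes that substitution numeric. The helper (differential) then reads off the ε part, and a gradient for adjoint-style use follows by probing each unit direction.

```python
class Dual:
    """Hypothetical dual number a + b*eps with eps**2 == 0."""

    def __init__(self, a, b=0.0):
        self.a, self.b = a, b

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.a + other.a, self.b + other.b)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # (a1 + b1 eps)(a2 + b2 eps) = a1 a2 + (a1 b2 + b1 a2) eps
        return Dual(self.a * other.a, self.a * other.b + self.b * other.a)

    __rmul__ = __mul__


def dot(u, v):
    """Dot product of two equal-length sequences."""
    total = 0
    for ui, vi in zip(u, v):
        total = total + ui * vi
    return total


def differential(F, u, du):
    """Hypothetical helper: dF(u)[du], the eps coefficient of F(u + eps*du)."""
    perturbed = [Dual(ui, dui) for ui, dui in zip(u, du)]
    return F(perturbed).b


if __name__ == "__main__":
    u = [0.3, -1.2]
    F = lambda v: dot(v, v)
    # Gradient components are dF(u)[e_j] for unit vectors e_j; expect 2u.
    print([differential(F, u, e) for e in ([1.0, 0.0], [0.0, 1.0])])
```

Dual numbers carry the perturbation exactly, with no step size to tune. A symbolic approach, which replaces the variable and differentiates with respect to ε, produces an expression instead of a number but follows the same substitution.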

liamoconnor9 commented 2 years ago

Thanks for pointing me in the right direction, I'll take a look at that.