slimgroup / JUDI.jl

Julia Devito inversion.
https://slimgroup.github.io/JUDI.jl
MIT License

Gradients w.r.t. both m and q are computed although only one of them is in Flux.params #123

Closed. ziyiyin97 closed this issue 2 years ago.

ziyiyin97 commented 2 years ago

Correct me if I'm wrong, but we don't have an AD rule for point-source modeling like https://github.com/slimgroup/JUDI4Flux.jl/blob/a21e71f35ba861618d9e7d02ef4e538bb036dc71/src/JUDI4Flux.jl#L125, right? That is, we only have the gradient w.r.t. m, not q. So do we need to add a couple of lines to rrules.jl?

mloubout commented 2 years ago

And what's that?

https://github.com/slimgroup/JUDI.jl/blob/master/test/test_rrules.jl#L112

mloubout commented 2 years ago

Or this: https://github.com/slimgroup/JUDI.jl/blob/master/src/rrules.jl#L6

ziyiyin97 commented 2 years ago
```julia
using JUDI
using ArgParse, Test, Printf, Aqua
using SegyIO, LinearAlgebra, Distributed, JOLI
using TimerOutputs: TimerOutputs, @timeit
using Flux

# JUDI seismic utils (setup_model, setup_geom)
include(joinpath(JUDIPATH,"../test/seismic_utils.jl"))

### Model
nsrc = 1
dt = 1f0

model, model0, dm = setup_model(false, false, 4)
m, m0 = model.m.data, model0.m.data
q, srcGeometry, recGeometry, f0 = setup_geom(model; nsrc=nsrc, dt=dt)

# Modeling operator and observed data
F = judiModeling(model, srcGeometry, recGeometry)
d_obs = F*q

# Only m0 is in Flux.params, so only the gradient w.r.t. m0 should be needed
g = gradient(Flux.params(m0)) do
    ϕ = 0.5f0*norm(F(m0, q) - d_obs)^2
    return ϕ
end
```

This runs an extra adjoint to compute the gradient w.r.t. q, which is not needed:

```
julia> include("MFE.jl")
Operator `forward` ran in 0.31 s
Operator `forward` ran in 0.35 s
Operator `forward` ran in 0.49 s
Operator `gradient` ran in 0.40 s
Operator `adjoint` ran in 0.33 s
Grads(...)
```
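To make the cost of the extra adjoint concrete, here is a minimal pure-Julia sketch (a hypothetical toy, not JUDI's `rrule`) contrasting a pullback that eagerly computes cotangents for both inputs with one that defers each cotangent until it is requested. The names `expensive_cotangent`, `pullback_eager` and `pullback_lazy` are made up for illustration, and the call counter stands in for an adjoint wave-equation solve.

```julia
# Hypothetical toy stand-in for a modeling operator F(m, q): d = m .* q.
# Not JUDI's API; used only to contrast eager and lazy cotangents.
const ncalls = Ref(0)  # counts how many "adjoint-like" computations ran

function expensive_cotangent(d̄, other)
    ncalls[] += 1          # pretend this is a full adjoint solve
    return d̄ .* other
end

# Eager pullback: both cotangents (w.r.t. m and w.r.t. q) are computed on every call.
pullback_eager(m, q, d̄) = (expensive_cotangent(d̄, q), expensive_cotangent(d̄, m))

# Lazy pullback: each cotangent is wrapped in a closure and only
# computed if the caller actually evaluates it.
pullback_lazy(m, q, d̄) = (() -> expensive_cotangent(d̄, q), () -> expensive_cotangent(d̄, m))

m = [1f0, 2f0]; q = [3f0, 4f0]; d̄ = [1f0, 1f0]

ncalls[] = 0
gm, gq = pullback_eager(m, q, d̄)
println("eager: ", ncalls[], " cotangent computations")  # prints "eager: 2 cotangent computations"

ncalls[] = 0
lazy_gm, lazy_gq = pullback_lazy(m, q, d̄)
gm_lazy = lazy_gm()        # only the gradient w.r.t. m is requested
println("lazy: ", ncalls[], " cotangent computations")   # prints "lazy: 1 cotangent computations"
```

In ChainRulesCore, `@thunk` plays the role of these closures inside an `rrule`. Whether the unused tangent is actually skipped, however, depends on whether the AD backend evaluates (unthunks) it anyway, so the extra adjoint can come from the backend rather than from the rule itself.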
ziyiyin97 commented 2 years ago

The MFE is here: https://julialang.slack.com/archives/C7LFJTXV5/p1655845955414799. It doesn't seem to be a problem with JUDI.