Closed: ziyiyin97 closed this issue 2 years ago.
```julia
using JUDI
using ArgParse, Test, Printf, Aqua
using SegyIO, LinearAlgebra, Distributed, JOLI
using TimerOutputs: TimerOutputs, @timeit
using Flux

# JUDI seismic utils
include(joinpath(JUDIPATH, "../test/seismic_utils.jl"))

### Model
nsrc = 1
dt = 1f0
model, model0, dm = setup_model(false, false, 4)
m, m0 = model.m.data, model0.m.data
q, srcGeometry, recGeometry, f0 = setup_geom(model; nsrc=nsrc, dt=dt)

# Common op
F = judiModeling(model, srcGeometry, recGeometry)
d_obs = F*q

g = gradient(Flux.params(m0)) do
    ϕ = .5f0*norm(F(m0, q) - d_obs)^2
    return ϕ
end
```
This does an extra adjoint to take the gradient w.r.t. `q`, which is not needed:
```
julia> include("MFE.jl")
Operator `forward` ran in 0.31 s
Operator `forward` ran in 0.35 s
Operator `forward` ran in 0.49 s
Operator `gradient` ran in 0.40 s
Operator `adjoint` ran in 0.33 s
Grads(...)
```
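One possible way to avoid paying for the `q` adjoint when only the gradient w.r.t. `m` is requested is to return the `q` cotangent from the `rrule` as a lazy `@thunk`, so it only runs if something unthunks it. This is a minimal sketch under assumptions, not JUDI's `rrule.jl`: `toy_forward` is a hypothetical elementwise stand-in for `F(m, q)`, and `NCALLS` is a hypothetical counter standing in for the `gradient`/`adjoint` operator calls in the log above. Whether this actually removes the extra call under `Flux.gradient` depends on the AD backend; Zygote, for example, may unthunk ChainRules cotangents eagerly, so that part is an assumption to verify.

```julia
using ChainRulesCore

# Hypothetical counters standing in for the cost of the `gradient` and
# `adjoint` operator calls seen in the MFE log.
const NCALLS = Dict(:gradient => 0, :adjoint => 0)

# Hypothetical toy stand-in for F(m, q): linear in the source q, like modeling.
toy_forward(m::AbstractVector, q::AbstractVector) = m .* q

function ChainRulesCore.rrule(::typeof(toy_forward), m::AbstractVector, q::AbstractVector)
    d = toy_forward(m, q)
    function toy_forward_pullback(Δd)
        Δ = unthunk(Δd)
        # Cotangent w.r.t. the model (the "gradient" call), computed only on unthunk.
        dm = @thunk begin
            NCALLS[:gradient] += 1
            q .* Δ
        end
        # Cotangent w.r.t. the source (the "extra adjoint"), computed only on unthunk.
        dq = @thunk begin
            NCALLS[:adjoint] += 1
            m .* Δ
        end
        return NoTangent(), dm, dq
    end
    return d, toy_forward_pullback
end

# Usage: ask for the model gradient only.
m0 = Float32[1, 2, 3]
q = Float32[1, 0, 0]
d, back = rrule(toy_forward, m0, q)
_, dm, dq = back(ones(Float32, 3))
g_m = unthunk(dm)  # materializes only the model gradient
```

After this runs, only the model-gradient branch has executed; the `q` branch is still an unevaluated thunk.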
The MFE is here: https://julialang.slack.com/archives/C7LFJTXV5/p1655845955414799 (it doesn't seem to be a problem with JUDI).
Correct me if I'm wrong, but we don't have an AD rule for point-source modeling like https://github.com/slimgroup/JUDI4Flux.jl/blob/a21e71f35ba861618d9e7d02ef4e538bb036dc71/src/JUDI4Flux.jl#L125, right? That is, we only have the gradient w.r.t. `m`, not `q`. So do we need to add a couple of lines to rrule.jl?
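For reference, the extra lines for `q` would compute the usual source gradient of the misfit in the MFE. Because modeling is linear in the source, for `ϕ(q) = ½‖F(m)q - d_obs‖²` we get `∇_q ϕ = F(m)' * (F(m)q - d_obs)`, i.e. one adjoint applied to the data residual. The sketch below checks that formula with a hypothetical dense matrix `A(m)` standing in for the point-source modeling operator (it is not JUDI's operator or rrule) and a central finite difference along a direction `δq`.

```julia
using LinearAlgebra

# Hypothetical dense stand-in for a point-source modeling operator F(m):
# maps a source vector q (length 4) to data d = A(m) * q (length 3).
A(m) = Diagonal(m) * reshape(collect(1.0:12.0), 3, 4) ./ 10

# Data misfit as in the MFE: ϕ(q) = ½‖F(m)q - d_obs‖².
misfit(m, q, d_obs) = 0.5 * norm(A(m) * q - d_obs)^2

# Gradient w.r.t. q: a single adjoint application to the data residual.
grad_q(m, q, d_obs) = A(m)' * (A(m) * q - d_obs)

m = [1.0, 2.0, 3.0]
q = [1.0, -1.0, 0.5, 2.0]
d_obs = [0.1, 0.2, 0.3]
δq = [0.3, 0.1, -0.2, 0.4]
h = 1e-6

# Central difference; exact up to round-off because ϕ is quadratic in q.
fd = (misfit(m, q .+ h .* δq, d_obs) - misfit(m, q .- h .* δq, d_obs)) / (2h)
an = dot(grad_q(m, q, d_obs), δq)
println("finite difference = $fd, adjoint-based = $an")
```

The same directional check could be applied to a new `q` branch in an rrule before relying on it.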