slimgroup / JUDI.jl

Julia Devito inversion.
https://slimgroup.github.io/JUDI.jl
MIT License

Enable scalar/broadcast operation for LazyPropagation #167

Open ziyiyin97 opened 1 year ago

ziyiyin97 commented 1 year ago
  1. Enable scalar/broadcast operations for `LazyPropagation`; add an associated test (which doesn't pass on the current master).
  2. `LazyPropagation` now has an attribute `val`, which stores `F * q` if it was previously computed.
  3. Fix the reshape issue for multi-source vectors, which can have size `nsrc` or size `nsrc * nt * nrec`.
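A toy model of items 1 and 2, under loudly stated assumptions: `ToyLazy` is a hypothetical stand-in, not JUDI's `LazyPropagation`, although it borrows the field names `post`, `F`, `q` and `val` from this PR. It shows scalar multiplication staying lazy by folding the scalar into `q` (valid because `F * (a*q) == a * (F*q)`, assuming `post` is linear too), broadcasting by materializing the value through `Base.broadcastable`, and caching the result in `val` so the expensive product is formed only once per object.

```julia
# Hypothetical stand-in for JUDI.LazyPropagation (NOT JUDI's definition):
# it represents post(F * q) without forming the product until needed.
mutable struct ToyLazy
    post::Function
    F::Matrix{Float32}
    q::Vector{Float32}
    val::Union{Nothing, Vector{Float32}}   # cached post(F * q), filled on first evaluation
end
ToyLazy(post, F, q) = ToyLazy(post, F, q, nothing)

const NEVALS = Ref(0)   # counts how many times the expensive product is formed

function eval_prop(L::ToyLazy)
    if L.val === nothing
        NEVALS[] += 1
        L.val = L.post(L.F * L.q)
    end
    return L.val
end

function Base.:*(a::Number, L::ToyLazy)
    # Reuse the cached value if it exists; otherwise stay lazy and fold `a` into q,
    # which is valid because F (and, by assumption, post) is linear in q.
    L.val === nothing ? ToyLazy(L.post, L.F, a .* L.q) : a .* L.val
end
Base.:*(L::ToyLazy, a::Number) = a * L
Base.:/(L::ToyLazy, a::Number) = inv(a) * L

# Broadcasting: materialize (and cache) the value, then broadcast it as an array.
Base.broadcastable(L::ToyLazy) = eval_prop(L)

Fm = Float32[1 2; 3 4]
L  = ToyLazy(identity, Fm, Float32[1, 1])
L2 = 2f0 * L          # still lazy: only q was scaled, no product formed yet
v2 = eval_prop(L2)    # first evaluation -> Float32[6, 14]
b  = L .+ 1f0         # broadcasting materializes L (second evaluation) -> Float32[4, 8]
h  = L / 2f0          # L is cached now, so this reuses L.val -> Float32[1.5, 3.5]
println(NEVALS[])     # 2
```

Returning a plain array once the value is cached makes `*` type-unstable; a fuller design would probably keep a consistent return type.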
codecov[bot] commented 1 year ago

Codecov Report

Base: 81.88% // Head: 81.59% // Decreases project coverage by 0.29% :warning:

Coverage data is based on head (171170f) compared to base (8f65ed4). Patch coverage: 41.17% of modified lines in the pull request are covered.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #167      +/-   ##
==========================================
- Coverage   81.88%   81.59%   -0.30%
==========================================
  Files          28       28
  Lines        2186     2200      +14
==========================================
+ Hits         1790     1795       +5
- Misses        396      405       +9
```

| [Impacted Files](https://codecov.io/gh/slimgroup/JUDI.jl/pull/167?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=slimgroup) | Coverage Δ | |
|---|---|---|
| [src/TimeModeling/LinearOperators/lazy.jl](https://codecov.io/gh/slimgroup/JUDI.jl/pull/167?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=slimgroup#diff-c3JjL1RpbWVNb2RlbGluZy9MaW5lYXJPcGVyYXRvcnMvbGF6eS5qbA==) | `83.72% <0.00%> (-0.99%)` | :arrow_down: |
| [src/rrules.jl](https://codecov.io/gh/slimgroup/JUDI.jl/pull/167?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=slimgroup#diff-c3JjL3JydWxlcy5qbA==) | `62.02% <40.00%> (-5.14%)` | :arrow_down: |
| [src/TimeModeling/Types/abstract.jl](https://codecov.io/gh/slimgroup/JUDI.jl/pull/167?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=slimgroup#diff-c3JjL1RpbWVNb2RlbGluZy9UeXBlcy9hYnN0cmFjdC5qbA==) | `77.41% <100.00%> (+0.18%)` | :arrow_up: |


mloubout commented 1 year ago

Your change leads to method ambiguities... please run these basic tests locally.
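For context on this failure mode: adding a method such as `*(::Float32, ::LazyPropagation)` can collide with an existing method that is equally specific for some argument combination, and Julia then refuses to pick either. Below is a minimal toy reproduction (the `Lazy` type and `scale` function are hypothetical, not JUDI code), plus the usual fix of defining a method for the intersection of the two signatures. `Test.detect_ambiguities` and Aqua's `Aqua.test_ambiguities` (Aqua is already loaded in the test script in this thread) can flag such pairs without having to call them.

```julia
# Hypothetical toy: `Lazy` stands in for an operator-like type that is not an
# AbstractArray. For (Float32, Lazy) neither method below is more specific
# than the other, so the call is ambiguous.
struct Lazy
    x::Float32
end

scale(a::Real, b) = "method for any Real left operand"
scale(a, b::Lazy) = "method for any Lazy right operand"

ambiguous = try
    scale(1f0, Lazy(2f0))
    false
catch err
    err isa MethodError   # an ambiguity surfaces as a MethodError at call time
end
println(ambiguous)        # true

# Fix: add a method for the intersection of the two signatures.
scale(a::Real, b::Lazy) = "disambiguated"
println(scale(1f0, Lazy(2f0)))   # disambiguated
```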

ziyiyin97 commented 1 year ago

On a side note: does it make sense to move the scalar operations (all of +-*/) into LazyPropagation.post?

mloubout commented 1 year ago

> ... into LazyPropagation.post?

No, because then it's not a linear operation anymore.
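A short sketch of the distinction, assuming "linear" here means linearity of the lazy map in the source `q`, and writing `post` as a map p applied to F q: scalar multiplication commutes with a linear F, so it can be folded into the source, whereas putting a scalar addition into `post` makes the map affine and breaks superposition.

```math
\begin{aligned}
F(a\,q) &= a\,(F q) && \text{scaling: } a \text{ can be folded into } q \\
p(y) = y + c:\quad p\big(F(q_1 + q_2)\big) &= F q_1 + F q_2 + c \\
&\neq p(F q_1) + p(F q_2) = F q_1 + F q_2 + 2c \qquad (c \neq 0)
\end{aligned}
```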

ziyiyin97 commented 1 year ago

Hmm, @mloubout, I'd appreciate your comment on this one. I am now on JUDI master and:

```julia
julia> gs_inv = gradient(x -> norm(F(x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.50 s
Operator `gradient` ran in 0.34 s
(Float32[-0.081900775 0.07301128 … 6.170804f-6 7.20752f-6; 0.0637427 0.027981473 … 9.756089f-7 5.4272978f-6; … ; 0.06374304 0.027981216 … 9.755976f-7 5.4272914f-6; -0.08189945 0.07301152 … 6.170794f-6 7.2075245f-6],)

julia> gs_inv1 = gradient(x -> norm(F(1f0*x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.55 s
Operator `forward` ran in 0.49 s
Operator `gradient` ran in 0.34 s
ERROR: MethodError: no method matching *(::Float32, ::JUDI.LazyPropagation)
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  *(::T, ::T) where T<:Union{Float16, Float32, Float64} at float.jl:385
  *(::Union{Float16, Float32, Float64}, ::BigFloat) at mpfr.jl:414
  ...
Stacktrace:
  [1] (::ChainRules.var"#1490#1494"{JUDI.LazyPropagation, Float32, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}})()
    @ ChainRules ~/.julia/packages/ChainRules/ajkp7/src/rulesets/Base/arraymath.jl:111
  [2] unthunk
    @ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:204 [inlined]
  [3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1490#1494"{JUDI.LazyPropagation, Float32, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, ChainRules.var"#1489#1493"{JUDI.LazyPropagation, Float32}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:237
  [4] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:105 [inlined]
  [5] map
    @ ./tuple.jl:223 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:106 [inlined]
  [7] ZBack
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:206 [inlined]
  [8] Pullback
    @ ./REPL[26]:1 [inlined]
  [9] (::typeof(∂(#10)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#60#61"{typeof(∂(#10))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
 [11] gradient(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
 [12] top-level scope
    @ REPL[26]:1
 [13] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52

julia> import Base.*;

julia> *(y::Float32, F::JUDI.LazyPropagation) = JUDI.LazyPropagation(F.post, F.F, *(y, F.q));

julia> gs_inv2 = gradient(x -> norm(F(1f0*x)*q), m0)
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.56 s
Operator `forward` ran in 0.53 s
Operator `gradient` ran in 0.34 s
Operator `forward` ran in 0.43 s
Operator `gradient` ran in 0.35 s
(Float32[-0.081900775 0.07301128 … 6.170804f-6 7.20752f-6; 0.0637427 0.027981473 … 9.756089f-7 5.4272978f-6; … ; 0.06374304 0.027981216 … 9.755976f-7 5.4272914f-6; -0.08189945 0.07301152 … 6.170794f-6 7.2075245f-6],)
```

`gs_inv` performs one nonlinear forward modeling and one RTM. `gs_inv1` fails because scalar multiplication with a `LazyPropagation` is not defined yet. After defining the multiplication, `gs_inv2` evaluates the `LazyPropagation` twice, which confuses me... any idea why? Thanks

Full script below

```julia
using JUDI
using Flux
using ArgParse, Test, Printf, Aqua
using SegyIO, LinearAlgebra, Distributed, JOLI
using TimerOutputs: TimerOutputs, @timeit

Flux.Random.seed!(2022)

### Model
tti = false
viscoacoustic = false
nsrc = 1
dt = 1f0

include(joinpath(JUDIPATH, "../test/seismic_utils.jl"))
model, model0, dm = setup_model(tti, viscoacoustic, 4)
m, m0 = model.m.data, model0.m.data
q, srcGeometry, recGeometry, f0 = setup_geom(model; nsrc=nsrc, dt=dt)

# Common op
Pr = judiProjection(recGeometry)
Ps = judiProjection(srcGeometry)

ra = false
stype = "Point"
Pq = Ps

opt = Options(return_array=ra, sum_padding=true, f0=f0)
A_inv = judiModeling(model; options=opt)
A_inv0 = judiModeling(model0; options=opt)

# Operators
F = Pr*A_inv*adjoint(Pq)
F0 = Pr*A_inv0*adjoint(Pq)

# Works: one forward modeling plus one gradient (RTM)
gs_inv = gradient(x -> norm(F(x)*q), m0)

# Throws MethodError: no method matching *(::Float32, ::JUDI.LazyPropagation)
gs_inv1 = gradient(x -> norm(F(1f0*x)*q), m0)

# Define scalar * LazyPropagation by folding the scalar into the source
import Base.*;
*(y::Float32, F::JUDI.LazyPropagation) = JUDI.LazyPropagation(F.post, F.F, *(y, F.q));

# Works, but the forward + gradient pair runs twice
gs_inv2 = gradient(x -> norm(F(1f0*x)*q), m0)
```
mloubout commented 1 year ago

That's quite curious indeed; I'll see if I can figure out what's going on.

mloubout commented 1 year ago

Well, that is baaaaaaaad; this is why people don't wanna use Julia for serious stuff.

When you do `gs_inv2 = gradient(x -> norm(F(1f0*x)*q), m0)`, Zygote doesn't correctly understand that you want "only" the derivative w.r.t. `m0`, in part because it doesn't understand thunks. So it ends up computing what you want, i.e. d(F(1*m0)*q)/d m0, but because the differentiation rules for `*` are defined for both the left and the right input (and, again, since Zygote always computes and evaluates everything), it also computes d(F(1*m0)*q)/d 1, i.e. the derivative w.r.t. the scalar factor, which calls `dot`, which calls `eval_prop`.

So there is no trivial way out of it, except maybe having `LazyPropagation` store the result at its first evaluation so it's only computed once (the above computes the same gradient twice).
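A pure-Julia sketch of the mechanism described above (no Zygote involved; `LazyGrad`, its `eval_prop` method and `pullback_scalar_mul` are hypothetical stand-ins, not JUDI or ChainRules code). The reverse rule for `y = a * x` forms both cotangents, `ā = dot(ȳ, x)` and `x̄ = a * ȳ`. When `ȳ` is lazy, computing `ā` already forces one evaluation and `x̄` needs another; memoizing the value at first evaluation, as suggested here, brings that down to one. In the toy, `x̄` is evaluated eagerly, whereas in JUDI it stays lazy until it is used.

```julia
using LinearAlgebra: dot

# Hypothetical lazy cotangent standing in for the LazyPropagation that is
# returned as the gradient; evaluating it is the expensive step.
mutable struct LazyGrad
    A::Matrix{Float64}
    r::Vector{Float64}
    val::Union{Nothing, Vector{Float64}}   # optional cache
end
LazyGrad(A, r) = LazyGrad(A, r, nothing)

const GRAD_EVALS = Ref(0)          # number of expensive evaluations
const USE_GRAD_CACHE = Ref(false)  # toggle memoization on/off

function eval_prop(g::LazyGrad)
    if USE_GRAD_CACHE[] && g.val !== nothing
        return g.val               # memoized: no recomputation
    end
    GRAD_EVALS[] += 1              # stands in for one forward + one gradient run
    v = g.A' * g.r
    if USE_GRAD_CACHE[]
        g.val = v
    end
    return v
end

# Reverse rule for y = a * x: both cotangents are formed, even if only x̄ is wanted.
function pullback_scalar_mul(a::Number, x::AbstractVector, ybar::LazyGrad)
    abar = dot(eval_prop(ybar), x)   # d/da needs the *value* of ybar -> evaluation
    xbar = a .* eval_prop(ybar)      # d/dx: the gradient we actually asked for
    return abar, xbar
end

A = [1.0 2.0; 3.0 4.0]
r = [1.0, -1.0]
x = [0.5, 0.25]

_, xbar = pullback_scalar_mul(1.0, x, LazyGrad(A, r))
println(GRAD_EVALS[])   # 2: once via dot (d/da), once for d/dx

GRAD_EVALS[] = 0
USE_GRAD_CACHE[] = true
_, xbar2 = pullback_scalar_mul(1.0, x, LazyGrad(A, r))
println(GRAD_EVALS[])   # 1: the second use hits the cached value
```

The cache only helps if every consumer of the lazy object (including scalar multiplication) reuses the stored value instead of building a fresh lazy object, which is part of the design question raised later in this thread.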

ziyiyin97 commented 1 year ago

Could you enlighten me on how (by code or otherwise) you reached the conclusion in this comment: https://github.com/slimgroup/JUDI.jl/pull/167#issuecomment-1373145567

I am experiencing the issue below and would like to check what went wrong:

```julia
julia> gs_inv = gradient(() -> norm(F(1f0*m)*q), Flux.params(m))
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.48 s
Operator `gradient` ran in 0.34 s
Grads(...)

julia> gs_inv = gradient(() -> norm(F(m*1f0)*q), Flux.params(m))
[ Info: Assuming m to be squared slowness for judiDataSourceModeling{Float32, :forward}
Operator `forward` ran in 0.54 s
Operator `forward` ran in 0.53 s
Operator `gradient` ran in 0.34 s
Operator `forward` ran in 0.49 s
Operator `gradient` ran in 0.34 s
Grads(...)
```
mloubout commented 1 year ago

Debug every `eval_prop` call to see where it's called from and what the inputs are. In that other case it was evaluated in `dot`; from there you can infer why, and check that it is indeed the gradient being computed by requesting it as a param.
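A sketch of that kind of instrumentation in plain Julia, with a hypothetical `expensive_prop` standing in for `eval_prop`: count the calls, print the caller frames from `stacktrace()`, and `@show` the inputs. With a local dev checkout of JUDI, the same few lines could go at the top of `eval_prop` itself (an assumption about your setup, not an existing JUDI debug option).

```julia
# Hypothetical stand-in for an expensive, lazily triggered call such as eval_prop.
const PROP_CALLS = Ref(0)

function expensive_prop(x)
    PROP_CALLS[] += 1
    # stacktrace()[1] is normally this function itself; the next frames show
    # who triggered the call (inlined frames may be reported or merged).
    frames = stacktrace()
    for fr in frames[2:min(end, 3)]
        println("call #", PROP_CALLS[], " <- ", fr.func, " at ", fr.file, ":", fr.line)
    end
    @show typeof(x) size(x)   # inspect the inputs of this particular call
    return 2 .* x
end

# Two different call sites, to see them distinguished in the output.
caller_a(x) = expensive_prop(x)
caller_b(x) = sum(expensive_prop(x))

caller_a([1.0, 2.0])
caller_b([3.0])
println("total evaluations: ", PROP_CALLS[])   # 2
```

Once an unexpected caller shows up (in this thread, `dot`), the corresponding cotangent can be requested explicitly to confirm which gradient is being computed.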

mloubout commented 1 year ago

Not sure where you are in the debugging, but I can tell you that it's not super trivial, and the fix will require some proper design to extend it cleanly to this type of case. But I'll leave it to you to at least find what the issue is, as an exercise.

ziyiyin97 commented 1 year ago

Thanks! Yes, I agree this is not simple. I will pick it up sometime later this week.