slimgroup / JUDI.jl

Julia Devito inversion.
https://slimgroup.github.io/JUDI.jl
MIT License

Output size of rrule is not the same as the size of the variable #124

Closed: ziyiyin97 closed this issue 2 years ago

ziyiyin97 commented 2 years ago

Here is a minimal failing example (MFE):

using JUDI
using ArgParse, Test, Printf, Aqua
using SegyIO, LinearAlgebra, Distributed, JOLI
using TimerOutputs: TimerOutputs, @timeit
using Flux

# JUDI seismic utils
include(joinpath(JUDIPATH,"../test/seismic_utils.jl"))

### Model
nsrc = 1
dt = 1f0

model, model0, dm = setup_model(false, false, 4)
m, m0 = model.m.data, model0.m.data
q, srcGeometry, recGeometry, f0 = setup_geom(model; nsrc=nsrc, dt=dt)

# Common op
F = judiModeling(model, srcGeometry, recGeometry)
d_obs = F*q

g = gradient(Flux.params(m0)) do
    ϕ = .5f0*norm(F(m0, q) - d_obs)^2
    return ϕ
end

## problem
println(size(m0)) # (301, 151)
println(size(g.grads[m0])) # (45451, 1)

### MFE starts now
v0 = sqrt.(1f0./m0)

gv = gradient(Flux.params(v0)) do
    ϕ = .5f0*norm(F((1f0./v0).^2f0, q) - d_obs)^2
    return ϕ
end

As the first example shows, m0 and its gradient have different sizes: size(m0) is (301, 151), while size(g.grads[m0]) is (45451, 1).

Consequently, differentiating with respect to v0 = sqrt.(1f0./m0) fails with a shape-mismatch error:

ERROR: LoadError: DimensionMismatch("variable with size(x) == (301, 151) cannot have a gradient with size(dx) == (45451,)")
Stacktrace:
  [1] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})(dx::Vector{Float32})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/RbX5a/src/projection.jl:227
  [2] _project
    @ ~/.julia/packages/Zygote/ytjqm/src/compiler/chainrules.jl:183 [inlined]
  [3] unbroadcast(x::Matrix{Float32}, x̄::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/broadcast.jl:51
  [4] map
    @ ./tuple.jl:247 [inlined]
  [5] (::Zygote.var"#∇broadcasted#1106"{Tuple{Matrix{Float32}, Float32}, Matrix{Tuple{Float32, Zygote.ZBack{ChainRules.var"#power_pullback#1212"{Float32, Float32, ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Float32}}}}, Val{3}})(ȳ::PhysicalParameter{Float32})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/broadcast.jl:198
  [6] #4012#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
  [7] #212
    @ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203 [inlined]
  [8] #1750#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
  [9] Pullback
    @ ./broadcast.jl:1303 [inlined]
 [10] (::typeof(∂(broadcasted)))(Δ::PhysicalParameter{Float32})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/Desktop/FNO4CO2/scripts/MFE.jl:35 [inlined]
 [12] (::typeof(∂(#8)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [13] (::Zygote.var"#93#94"{Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, typeof(∂(#8)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:357
 [14] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76
 [15] top-level scope
    @ ~/Desktop/FNO4CO2/scripts/MFE.jl:34
 [16] include(fname::String)
    @ Base.MainInclude ./client.jl:451
 [17] top-level scope
    @ REPL[1]:1
in expression starting at /Users/francisyin/Desktop/FNO4CO2/scripts/MFE.jl:34
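A minimal sketch of the mismatch in plain Julia, independent of JUDI and Zygote: the gradient arrives flattened, while the variable is a 2D array with the same number of entries (301 × 151 = 45451). The helper match_shape is hypothetical and only illustrates the reshape that makes the two agree; it is not how JUDI resolves the issue.

```julia
# Hypothetical helper (not part of JUDI): give a flat gradient the shape
# of the variable it belongs to, e.g. a (301, 151) model array.
function match_shape(dx::AbstractArray, x::AbstractArray)
    # A reshape is only valid if no entries are added or dropped.
    length(dx) == length(x) || throw(DimensionMismatch(
        "gradient has $(length(dx)) entries, variable has $(length(x))"))
    return reshape(dx, size(x))
end

x  = zeros(Float32, 301, 151)  # stands in for the 2D array m0 or v0
dx = ones(Float32, 45451)      # stands in for the flattened gradient
g  = match_shape(dx, x)
println(size(g))               # (301, 151)
```

A reshape only reinterprets the same entries in column-major order, so it is safe here precisely because the counts match; the length check rejects any other case.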
ziyiyin97 commented 2 years ago

By the way, there is no problem if the modeling operator is built with return_array=true:


options = Options(return_array=true)
F = judiModeling(model, srcGeometry, recGeometry; options=options)
d_obs = F*q
d_obs = F.options.return_array ? reshape(d_obs, F.rInterpolation, F.model; with_batch=true) : d_obs
ziyiyin97 commented 2 years ago

This is a bit tricky because the shape of dm is not known here: https://github.com/slimgroup/JUDI.jl/blob/6c539b1cb002f54bd4d7c59bcbf55cab27d36f5d/src/rrules.jl#L25 Any thoughts?
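One general pattern for this kind of problem, sketched here under the assumption that the forward pass sees the input array: record size(input) during the forward call and let the pullback closure reshape the gradient to that size. The names toy_forward and toy_pullback are hypothetical, and the "modeling" is a toy linear map on the flattened model, not JUDI's rrule.

```julia
# Hypothetical forward operator that works on the flattened model internally
# and returns a pullback closure, in the spirit of an rrule.
function toy_forward(m::AbstractArray{T}) where {T}
    sz = size(m)                  # shape of the input, known at forward time
    y  = T(2) .* vec(m)           # toy linear "modeling" on the flat model
    # The pullback returns the gradient in the caller's original shape.
    toy_pullback(ȳ::AbstractVector) = reshape(T(2) .* ȳ, sz)
    return y, toy_pullback
end

m = ones(Float32, 301, 151)
y, back = toy_forward(m)
dm = back(ones(Float32, length(y)))
println(size(dm))                 # (301, 151)

# Passing a vector instead gives a vector-shaped gradient.
y2, back2 = toy_forward(vec(m))
println(size(back2(y2)))          # (45451,)
```

Because the shape is captured in the closure, the pullback never has to guess whether the caller used a matrix or a vector.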

mloubout commented 2 years ago

You should work either with a vector or with a PhysicalParameter, not with an array.

ziyiyin97 commented 2 years ago

Thanks. Strictly speaking, we need to work with a Matrix of size (prod(n), 1) because of https://github.com/slimgroup/JUDI.jl/blob/5e3b99f8c90298516170685e0f22129650e2f822/src/TimeModeling/Types/ModelStructure.jl#L94. Is there a reason why PhysicalParameter, which is an AbstractVector, has a size of length 2, i.e. why is it not (prod(n),)?

mloubout commented 2 years ago

It should be (prod(A.n),).
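To illustrate the answer, here is a minimal sketch of an AbstractVector subtype whose size is a one-element tuple. ToyParameter is a hypothetical stand-in, not JUDI's PhysicalParameter: it stores the grid dimensions n next to flat data and implements only the required AbstractArray interface methods.

```julia
# Hypothetical stand-in for a gridded parameter that subtypes AbstractVector.
# `n` holds the grid dimensions; `data` holds the values in flat order.
struct ToyParameter{T,N} <: AbstractVector{T}
    n::NTuple{N,Int}
    data::Vector{T}
end

# A one-element size tuple is what AbstractVector (= AbstractArray{T,1}) expects.
Base.size(A::ToyParameter) = (prod(A.n),)
Base.getindex(A::ToyParameter, i::Int) = A.data[i]
Base.IndexStyle(::Type{<:ToyParameter}) = IndexLinear()

p = ToyParameter((301, 151), zeros(Float32, 301 * 151))
println(size(p))             # (45451,)
grid = reshape(p.data, p.n)  # back to the 2D grid when needed
println(size(grid))          # (301, 151)
```

Since AbstractVector{T} is an alias for AbstractArray{T,1}, a two-element size such as (prod(n), 1) contradicts the type's own dimensionality, whereas (prod(n),) keeps ndims, length and broadcasting consistent with every other vector.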