JuliaPhysics / WaveOpticsPropagation.jl

Propagate waves efficiently, optically, physically, differentiably with Julia Lang.
MIT License
11 stars 1 forks source link

ChainRules with Tagent #13

Open roflmaostc opened 6 months ago

roflmaostc commented 6 months ago

The following fails:

function ChainRulesCore.rrule(as::AngularSpectrum3, field)
    field_and_tuple = as(field) 
    function as_pullback(ȳ)
        f̄ = NoTangent()
        y2 = ȳ

        fill!(as.buffer2, 0)

        # THIS LINE FAILS
        field_new = as.padding ? ∇set_center!(y2, as.buffer2, field, broadcast=true) : y2  
        field_imd = as.p * ifftshift!(as.buffer, field_new, (1, 2)) 
        field_imd .*= conj.(as.HW)
        field_out = fftshift!(as.buffer2, inv(as.p) * field_imd, (1, 2)) 
        field_out_cropped = as.padding ? crop_center(field_out, size(field), return_view=true) : field_out
        return f̄, field_out_cropped 
    end 
    return field_and_tuple, as_pullback
end

function ∇set_center!(dy, arr_large::AbstractArray{T, N}, arr_small::AbstractArray{T1, M};
                     broadcast=false) where {T, T1, M, N}
    @assert N ≥ M "Can't put a higher dimensional array in a lower dimensional one."

    if broadcast == false
        inds = ntuple(i -> begin
                        a, b = get_indices_around_center(size(arr_large, i), size(arr_small, i))
                        a:b
                      end,
                      Val(N))
        arr_large[inds..., ..] .= dy
    else
        inds = ntuple(i -> begin
                        a, b = get_indices_around_center(size(arr_large, i), size(arr_small, i))
                        a:b
                      end,
                      Val(M))
        # THIS LINE fails with broadcasting
        arr_large[inds..., ..] .= dy
    end

    return arr_large
end

with

julia> include("test/angular_spectrum.jl")
typeof(dy) = Tangent{Any, Tuple{Matrix{ComplexF64}, ZeroTangent}}
Test gradient with Finite Differences: Error During Test at /home/fxw/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:3
  Got exception outside of a @test
  DimensionMismatch: array could not be broadcast to match destination
  Stacktrace:
    [1] check_broadcast_shape
      @ ./broadcast.jl:579 [inlined]
    [2] check_broadcast_axes
      @ ./broadcast.jl:582 [inlined]
    [3] instantiate
      @ ./broadcast.jl:309 [inlined]
    [4] materialize!
      @ ./broadcast.jl:914 [inlined]
    [5] materialize!
      @ ./broadcast.jl:911 [inlined]
    [6] ∇set_center!(dy::Tangent{Any, Tuple{Matrix{ComplexF64}, ZeroTangent}}, arr_large::Matrix{ComplexF64}, arr_small::Matrix{ComplexF64}; broadcast::Bool)
      @ WaveOpticsPropagation ~/.julia/dev/WaveOpticsPropagation.jl/src/utils.jl:248
    [7] ∇set_center!
      @ ~/.julia/dev/WaveOpticsPropagation.jl/src/utils.jl:230 [inlined]
    [8] (::WaveOpticsPropagation.var"#as_pullback#166"{WaveOpticsPropagation.AngularSpectrum3{Matrix{ComplexF64}, Float64, FFTW.cFFTWPlan{ComplexF64, -1, true, 2, Tuple{Int64, Int64}}}, Matrix{ComplexF64}})(ȳ::Tangent{Any, Tuple{Matrix{ComplexF64}, ZeroTangent}})
      @ WaveOpticsPropagation ~/.julia/dev/WaveOpticsPropagation.jl/src/angular_spectrum.jl:200
    [9] (::Zygote.ZBack{WaveOpticsPropagation.var"#as_pullback#166"{WaveOpticsPropagation.AngularSpectrum3{Matrix{ComplexF64}, Float64, FFTW.cFFTWPlan{ComplexF64, -1, true, 2, Tuple{Int64, Int64}}}, Matrix{ComplexF64}}})(dy::Tuple{Matrix{ComplexF64}, Nothing})
      @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/chainrules.jl:211
   [10] f_AS
      @ ~/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:15 [inlined]
   [11] (::Zygote.Pullback{Tuple{var"#f_AS#132", Matrix{ComplexF64}}, Any})(Δ::Float64)
      @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
   [12] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#f_AS#132", Matrix{ComplexF64}}, Any}})(Δ::Float64)
      @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:45
   [13] gradient(f::Function, args::Matrix{ComplexF64})
      @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:97
   [14] macro expansion
      @ ~/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:17 [inlined]
   [15] macro expansion
      @ ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [16] macro expansion
      @ ~/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:4 [inlined]
   [17] macro expansion
      @ ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [18] top-level scope
      @ ~/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:3
   [19] include(fname::String)
      @ Base.MainInclude ./client.jl:489
   [20] top-level scope
      @ REPL[21]:1
   [21] top-level scope
      @ ~/.julia/packages/CUDA/rXson/src/initialization.jl:208
   [22] eval
      @ Core ./boot.jl:385 [inlined]
   [23] eval_user_input(ast::Any, backend::REPL.REPLBackend, mod::Module)
      @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
   [24] repl_backend_loop(backend::REPL.REPLBackend, get_module::Function)
      @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
   [25] start_repl_backend(backend::REPL.REPLBackend, consumer::Any; get_module::Function)
      @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
   [26] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool, backend::Any)
      @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
   [27] run_repl(repl::REPL.AbstractREPL, consumer::Any)
      @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
   [28] (::Base.var"#1013#1015"{Bool, Bool, Bool})(REPL::Module)
      @ Base ./client.jl:432
   [29] #invokelatest#2
      @ Base ./essentials.jl:887 [inlined]
   [30] invokelatest
      @ Base ./essentials.jl:884 [inlined]
   [31] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool)
      @ Base ./client.jl:416
   [32] exec_options(opts::Base.JLOptions)
      @ Base ./client.jl:333
   [33] _start()
      @ Base ./client.jl:552