Open roflmaostc opened 6 months ago
The following fails:
function ChainRulesCore.rrule(as::AngularSpectrum3, field) field_and_tuple = as(field) function as_pullback(ȳ) f̄ = NoTangent() y2 = ȳ fill!(as.buffer2, 0) # THIS LINE FAILS field_new = as.padding ? ∇set_center!(y2, as.buffer2, field, broadcast=true) : y2 field_imd = as.p * ifftshift!(as.buffer, field_new, (1, 2)) field_imd .*= conj.(as.HW) field_out = fftshift!(as.buffer2, inv(as.p) * field_imd, (1, 2)) field_out_cropped = as.padding ? crop_center(field_out, size(field), return_view=true) : field_out return f̄, field_out_cropped end return field_and_tuple, as_pullback end function ∇set_center!(dy, arr_large::AbstractArray{T, N}, arr_small::AbstractArray{T1, M}; broadcast=false) where {T, T1, M, N} @assert N ≥ M "Can't put a higher dimensional array in a lower dimensional one." if broadcast == false inds = ntuple(i -> begin a, b = get_indices_around_center(size(arr_large, i), size(arr_small, i)) a:b end, Val(N)) arr_large[inds..., ..] .= dy else inds = ntuple(i -> begin a, b = get_indices_around_center(size(arr_large, i), size(arr_small, i)) a:b end, Val(M)) # THIS LINE fails with broadcasting arr_large[inds..., ..] .= dy end return arr_large end
with
julia> include("test/angular_spectrum.jl") typeof(dy) = Tangent{Any, Tuple{Matrix{ComplexF64}, ZeroTangent}} Test gradient with Finite Differences: Error During Test at /home/fxw/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:3 Got exception outside of a @test DimensionMismatch: array could not be broadcast to match destination Stacktrace: [1] check_broadcast_shape @ ./broadcast.jl:579 [inlined] [2] check_broadcast_axes @ ./broadcast.jl:582 [inlined] [3] instantiate @ ./broadcast.jl:309 [inlined] [4] materialize! @ ./broadcast.jl:914 [inlined] [5] materialize! @ ./broadcast.jl:911 [inlined] [6] ∇set_center!(dy::Tangent{Any, Tuple{Matrix{ComplexF64}, ZeroTangent}}, arr_large::Matrix{ComplexF64}, arr_small::Matrix{ComplexF64}; broadcast::Bool) @ WaveOpticsPropagation ~/.julia/dev/WaveOpticsPropagation.jl/src/utils.jl:248 [7] ∇set_center! @ ~/.julia/dev/WaveOpticsPropagation.jl/src/utils.jl:230 [inlined] [8] (::WaveOpticsPropagation.var"#as_pullback#166"{WaveOpticsPropagation.AngularSpectrum3{Matrix{ComplexF64}, Float64, FFTW.cFFTWPlan{ComplexF64, -1, true, 2, Tuple{Int64, Int64}}}, Matrix{ComplexF64}})(ȳ::Tangent{Any, Tuple{Matrix{ComplexF64}, ZeroTangent}}) @ WaveOpticsPropagation ~/.julia/dev/WaveOpticsPropagation.jl/src/angular_spectrum.jl:200 [9] (::Zygote.ZBack{WaveOpticsPropagation.var"#as_pullback#166"{WaveOpticsPropagation.AngularSpectrum3{Matrix{ComplexF64}, Float64, FFTW.cFFTWPlan{ComplexF64, -1, true, 2, Tuple{Int64, Int64}}}, Matrix{ComplexF64}}})(dy::Tuple{Matrix{ComplexF64}, Nothing}) @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/chainrules.jl:211 [10] f_AS @ ~/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:15 [inlined] [11] (::Zygote.Pullback{Tuple{var"#f_AS#132", Matrix{ComplexF64}}, Any})(Δ::Float64) @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0 [12] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#f_AS#132", Matrix{ComplexF64}}, Any}})(Δ::Float64) @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:45 [13] gradient(f::Function, args::Matrix{ComplexF64}) @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:97 [14] macro expansion @ ~/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:17 [inlined] [15] macro expansion @ ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined] [16] macro expansion @ ~/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:4 [inlined] [17] macro expansion @ ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined] [18] top-level scope @ ~/.julia/dev/WaveOpticsPropagation.jl/test/angular_spectrum.jl:3 [19] include(fname::String) @ Base.MainInclude ./client.jl:489 [20] top-level scope @ REPL[21]:1 [21] top-level scope @ ~/.julia/packages/CUDA/rXson/src/initialization.jl:208 [22] eval @ Core ./boot.jl:385 [inlined] [23] eval_user_input(ast::Any, backend::REPL.REPLBackend, mod::Module) @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150 [24] repl_backend_loop(backend::REPL.REPLBackend, get_module::Function) @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246 [25] start_repl_backend(backend::REPL.REPLBackend, consumer::Any; get_module::Function) @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231 [26] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool, backend::Any) @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389 [27] run_repl(repl::REPL.AbstractREPL, consumer::Any) @ REPL ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375 [28] (::Base.var"#1013#1015"{Bool, Bool, Bool})(REPL::Module) @ Base ./client.jl:432 [29] #invokelatest#2 @ Base ./essentials.jl:887 [inlined] [30] invokelatest @ Base ./essentials.jl:884 [inlined] [31] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) @ Base ./client.jl:416 [32] exec_options(opts::Base.JLOptions) @ Base ./client.jl:333 [33] _start() @ Base ./client.jl:552
The following fails:
with