FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

`abs2` of complex CUDA array fails with `Zygote.gradient` #961

Closed roflmaostc closed 1 year ago

roflmaostc commented 3 years ago

Hey,

When applying `abs2` to a complex CUDA array, I get `ERROR: MethodError: no method matching iterate(::Nothing)`. I'm using CUDA 3.1.0, Julia 1.6.1 and Zygote 0.6.10. I also tried Julia 1.5.4, CUDA v2.4.0 and Zygote v0.5.0, so it is not a recently introduced issue.

See the MWE below:

julia> using Zygote, CUDA

julia> x = rand(ComplexF32, (2,2))
2×2 Matrix{ComplexF32}:
 0.0598111+0.678913im  0.767138+0.77825im
  0.548067+0.98656im   0.306103+0.166084im

julia> x_c = CuArray(x);

julia> f(x) = sum(abs2.(x))
f (generic function with 1 method)

julia> g(x) = sum(real(x .* conj.(x)))
g (generic function with 1 method)

julia> f(x) ≈ f(x_c) ≈ g(x) ≈ g(x_c)
true

julia> Zygote.gradient(f, x)
(ComplexF32[0.11962223f0 + 1.3578255f0im 1.534275f0 + 1.5565007f0im; 1.0961342f0 + 1.9731205f0im 0.61220574f0 + 0.33216715f0im],)

julia> Zygote.gradient(g, x)
(ComplexF32[0.11962223f0 + 1.3578255f0im 1.534275f0 + 1.5565007f0im; 1.0961342f0 + 1.9731205f0im 0.61220574f0 + 0.33216715f0im],)

julia> Zygote.gradient(f, x_c)
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:664
  iterate(::Union{LinRange, StepRangeLen}, ::Int64) at range.jl:664
  iterate(::T) where T<:Union{Base.KeySet{var"#s79", var"#s78"} where {var"#s79", var"#s78"<:Dict}, Base.ValueIterator{var"#s77"} where var"#s77"<:Dict} at dict.jl:693
  ...
Stacktrace:
  [1] (::Zygote.var"#1209#1210"{Zygote.var"#1104#1108"})(ȳ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/broadcast.jl:231
  [2] (::Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [3] (::Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:194
  [4] (::Zygote.var"#1689#back#182"{Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [5] Pullback
    @ ./broadcast.jl:1309 [inlined]
  [6] Pullback
    @ ./REPL[18]:1 [inlined]
  [7] (::typeof(∂(f)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [8] (::Zygote.var"#41#42"{typeof(∂(f))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
  [9] gradient(f::Function, args::CuArray{ComplexF32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [10] top-level scope
    @ REPL[24]:1
 [11] top-level scope
    @ ~/.julia/packages/CUDA/k52QH/src/initialization.jl:81

julia> Zygote.gradient(g, x_c)
(ComplexF32[0.11962223f0 + 1.3578255f0im 1.534275f0 + 1.5565007f0im; 1.0961342f0 + 1.9731205f0im 0.61220574f0 + 0.33216715f0im],)
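The CPU results above can be sanity-checked without Zygote. The sketch below is plain Julia with no packages; the helper `fd_gradient` is hypothetical, and it assumes Zygote reports the gradient of a real-valued function of z = a + ib as ∂f/∂a + i·∂f/∂b, which is consistent with the printed outputs (every entry equals 2x).

```julia
# Central finite-difference gradient of a real-valued f at a complex array z.
# Assumed convention (matches the outputs above): grad[k] = ∂f/∂re(z[k]) + im * ∂f/∂im(z[k]).
function fd_gradient(f, z::AbstractArray{<:Complex}; h = 1e-6)
    grad = similar(z)
    for k in eachindex(z)
        e = zero(z)
        e[k] = 1                                                      # unit step in entry k
        d_re = (f(z .+ h .* e) - f(z .- h .* e)) / (2h)               # perturb the real part
        d_im = (f(z .+ im * h .* e) - f(z .- im * h .* e)) / (2h)     # perturb the imaginary part
        grad[k] = complex(d_re, d_im)
    end
    return grad
end

f(x) = sum(abs2.(x))
g(x) = sum(real(x .* conj.(x)))

# The CPU input from the MWE, promoted to ComplexF64 for accurate differencing.
x = ComplexF64[0.0598111+0.678913im 0.767138+0.77825im
               0.548067+0.98656im   0.306103+0.166084im]

println(isapprox(fd_gradient(f, x), 2 .* x; rtol = 1e-6))  # true
println(isapprox(fd_gradient(g, x), 2 .* x; rtol = 1e-6))  # true
```

Both formulations give 2x, the same values Zygote returns on the CPU, so only the CUDA path is misbehaving.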
Manifest.toml

# This file is machine-generated - editing it directly is not advised

[[AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.0.1"

[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "f1b523983a58802c4695851926203b36e28f09db"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.3.0"

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

[[BFloat16s]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "4af69e205efc343068dc8722b8dfec1ade89254a"
uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
version = "0.1.0"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[CEnum]]
git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.1"

[[CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "TimerOutputs"]
git-tree-sha1 = "d4fa6486e94c4087f1d081d7be2d501a170bd51d"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "3.1.0"

[[ChainRules]]
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"]
git-tree-sha1 = "1f410fba5c04d03ab712f348f1542e6059376547"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.61"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "bd0cc939d94b8bd736dce5bbbe0d635db9f94af7"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.41"

[[CommonSubexpressions]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.3.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ac4132ad78082518ec2037ae5770b6e796f7f956"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.27.0"

[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.9"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DiffResults]]
deps = ["StaticArrays"]
git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "1.0.3"

[[DiffRules]]
deps = ["NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "214c3fcac57755cfda163d91c58893a8723f93e9"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.0.2"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[ExprTools]]
git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e"
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
version = "0.1.3"

[[FFTW]]
deps = ["AbstractFFTs", "FFTW_jll", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"]
git-tree-sha1 = "1dc6ca6ad69eb9beadd3ce82b90910f4fa63d7c3"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "1.4.0"

[[FFTW_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "5a0d4b6a22a34d17d53543bd124f4b08ed78e8b0"
uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a"
version = "3.3.9+7"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.11.7"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.18"

[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
git-tree-sha1 = "3e10e95ddc385e1589c27b1a58f21bf3008b559c"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "6.3.0"

[[GPUCompiler]]
deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "6eadd2321dc3ac0fc9d530ab01c2caa7fe5d74c6"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.11.4"

[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.4.2"

[[IntelOpenMP_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "d979e54b71da82f3a65b62553da4fc3d18c9004c"
uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0"
version = "2018.0.3+2"

[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[JLLWrappers]]
deps = ["Preferences"]
git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.3.0"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "3.6.0"

[[LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"

[[LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"

[[LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"

[[LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MKL_jll]]
deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
git-tree-sha1 = "c253236b0ed414624b083e6b72bfe891fbd2c7af"
uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
version = "2021.1.1+1"

[[MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.6"

[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"

[[Memoize]]
deps = ["MacroTools"]
git-tree-sha1 = "2b1dfcba103de714d31c033b5dacc2e4a12c7caa"
uuid = "c03570c3-d221-55d1-a50c-7939bbd78826"
version = "0.4.4"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"

[[NaNMath]]
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.5"

[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "b9b8b8ed236998f91143938a760c2112dceeb2b4"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.4+0"

[[OrderedCollections]]
git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.0"

[[Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Preferences]]
deps = ["TOML"]
git-tree-sha1 = "ea79e4c9077208cd3bc5d29631a26bc0cff78902"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.2.1"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[RandomNumbers]]
deps = ["Random", "Requires"]
git-tree-sha1 = "441e6fc35597524ada7f85e13df1f4e10137d16f"
uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
version = "1.4.0"

[[Reexport]]
git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "1.0.0"

[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.3"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[Scratch]]
deps = ["Dates"]
git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6"
uuid = "6c6a2e73-6563-6170-7368-637461726353"
version = "1.0.3"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "OpenSpecFun_jll"]
git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.3.0"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "2653e9c769343808781a8bd5010ee7a17c01152e"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.1.2"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[[Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"

[[Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["Printf"]
git-tree-sha1 = "32cdbe6cd2d214c25a0b88f985c9e0092877c236"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.8"

[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"

[[Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "927209c83efa62256788a9880c191774c07c5b51"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.10"

[[ZygoteRules]]
deps = ["MacroTools"]
git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.1"

[[nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"

[[p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"

Thanks a lot!

Felix

roflmaostc commented 3 years ago

Uh, I believe the problem is more general than what I posted above. Several basic functions are broken:


julia> f_exp(x) = sum(real(exp.(x)))
f_exp (generic function with 1 method)

julia> Zygote.gradient(f_exp, x_c)
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:664
  iterate(::Union{LinRange, StepRangeLen}, ::Int64) at range.jl:664
  iterate(::T) where T<:Union{Base.KeySet{var"#s79", var"#s78"} where {var"#s79", var"#s78"<:Dict}, Base.ValueIterator{var"#s77"} where var"#s77"<:Dict} at dict.jl:693
  ...
Stacktrace:
  [1] (::Zygote.var"#1209#1210"{Zygote.var"#1104#1108"})(ȳ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/broadcast.jl:231
  [2] (::Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [3] (::Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:194
  [4] (::Zygote.var"#1689#back#182"{Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [5] Pullback
    @ ./broadcast.jl:1309 [inlined]
  [6] Pullback
    @ ./REPL[19]:1 [inlined]
  [7] (::typeof(∂(f_exp)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [8] (::Zygote.var"#41#42"{typeof(∂(f_exp))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
  [9] gradient(f::Function, args::CuArray{ComplexF32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [10] top-level scope
    @ REPL[20]:1
 [11] top-level scope
    @ ~/.julia/packages/CUDA/k52QH/src/initialization.jl:81

julia> f_abs(x) = sum(real(abs.(x)))
f_abs (generic function with 1 method)

julia> Zygote.gradient(f_abs, x_c)
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:664
  iterate(::Union{LinRange, StepRangeLen}, ::Int64) at range.jl:664
  iterate(::T) where T<:Union{Base.KeySet{var"#s79", var"#s78"} where {var"#s79", var"#s78"<:Dict}, Base.ValueIterator{var"#s77"} where var"#s77"<:Dict} at dict.jl:693
  ...
Stacktrace:
  [1] (::Zygote.var"#1209#1210"{Zygote.var"#1104#1108"})(ȳ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/broadcast.jl:231
  [2] (::Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [3] (::Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:194
  [4] (::Zygote.var"#1689#back#182"{Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [5] Pullback
    @ ./broadcast.jl:1309 [inlined]
  [6] Pullback
    @ ./REPL[27]:1 [inlined]
  [7] (::typeof(∂(f_abs)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [8] (::Zygote.var"#41#42"{typeof(∂(f_abs))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
  [9] gradient(f::Function, args::CuArray{ComplexF32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [10] top-level scope
    @ REPL[28]:1
 [11] top-level scope
    @ ~/.julia/packages/CUDA/k52QH/src/initialization.jl:81

DhairyaLGandhi commented 3 years ago

Interesting, it's usually the complex numbers, right? We might be better off making sure our adjoints handle complex numbers properly; some of it probably needs fixing in ChainRules too. Ideally we'd be able to reproduce this without relying on CUDA. Is that possible?

roflmaostc commented 3 years ago

Without CUDA it seems to work fine; I haven't encountered the issue there. The examples above are run both with and without CUDA, and the only failing combination is Zygote + a complex array (`Array{<:Complex}`) + CUDA.

ToucheSir commented 3 years ago

The relevant lines: https://github.com/FluxML/Zygote.jl/blob/v0.6.10/src/lib/broadcast.jl#L195-L233

It appears that broadcasting over complex numbers hits the early bail-out in https://github.com/FluxML/Zygote.jl/blob/v0.6.10/src/lib/broadcast.jl#L213, which returns a pullback that itself returns `nothing` (and thus can't be splatted). I would assume this is not the intended behaviour and that at least some value should be propagated?
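The `iterate(::Nothing)` message follows directly from such a bail-out: the caller splats whatever the inner pullback returns, and splatting `nothing` calls `iterate(nothing)`. A minimal plain-Julia reproduction; `bailout_back` and `outer_back` are hypothetical stand-ins, not Zygote's actual internals:

```julia
# Stand-in for the bail-out pullback: returns `nothing` instead of a tuple
# of per-argument gradients.
bailout_back(Δ) = nothing

# Stand-in for a caller that splats the inner pullback's result into a tuple.
outer_back(back, Δ) = (nothing, back(Δ)...)

err = try
    outer_back(bailout_back, 1.0f0)
    nothing
catch e
    e
end

println(err isa MethodError)                                  # true
println(occursin("iterate(::Nothing)", sprint(showerror, err))) # true
```

So the error message points at the splat, not at the place where the gradient was actually lost.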

roflmaostc commented 3 years ago

Hm, I tried to do some debugging.

`T = Broadcast.combine_eltypes(f, args)` is always `Float32`, regardless of whether we have a `CuArray{<:Complex}` or a `CuArray{<:Real}`. However, that `T` is not used anywhere.

I'm not sure, but this looks strange, doesn't it?

julia> Zygote.dual_function(abs2).(CuArray(randn(Float32, (2,2))))
2×2 CuArray{Dual{Nothing, Float32, 1}, 2}:
 Dual{Nothing}(0.0782065,-0.559308)  Dual{Nothing}(0.000581367,-0.0482231)
 Dual{Nothing}(0.189928,-0.871615)   Dual{Nothing}(0.37119,1.21851)

julia> Zygote.dual_function(abs2).(CuArray(randn(ComplexF32, (2,2))))
2×2 CuArray{Float32, 2}:
 2.71749   0.407771
 0.735002  0.240149
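Regarding `combine_eltypes` always being `Float32`: as far as I can tell, `Base.Broadcast.combine_eltypes(f, args)` (an internal, unexported Base helper) infers the element type of the broadcast *output*, not of the inputs. Since `abs2` maps complex numbers to reals, it reports `Float32` for both real and complex `Float32` inputs, so that `T` alone cannot distinguish the two cases. A quick check in plain Julia, with CPU arrays standing in for `CuArray`s:

```julia
using Base.Broadcast: combine_eltypes  # internal helper; may change between Julia versions

xr = rand(Float32, 2, 2)
xc = rand(ComplexF32, 2, 2)

# Element type of the broadcast result: abs2 returns a real number either way.
println(combine_eltypes(abs2, (xr,)))  # Float32
println(combine_eltypes(abs2, (xc,)))  # Float32

# For a function whose output stays complex, the inferred type is complex.
println(combine_eltypes(exp, (xc,)) == ComplexF32)  # true
```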

The following

 @adjoint broadcasted(::typeof(abs2), x::CuArray) = 
            abs2.(x), a -> (nothing, 2 .* a .* x) # copied but replaced Numeric with CuArray

This also fixes the issue, so it works, but it is not a general solution. There must be a general pattern for this...
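Why `2 .* a .* x` is the right pullback, as a short derivation. It assumes Zygote's convention (consistent with the CPU outputs in the MWE) that the gradient of a real-valued L with respect to z_k = u_k + i v_k is ∂L/∂u_k + i ∂L/∂v_k, and that the incoming cotangent a is real because `abs2` has real output:

```math
L = \sum_k a_k \lvert z_k \rvert^2 = \sum_k a_k \left(u_k^2 + v_k^2\right),
\qquad
\frac{\partial L}{\partial u_k} + i\,\frac{\partial L}{\partial v_k}
= 2 a_k u_k + 2 i\, a_k v_k = 2 a_k z_k .
```

Elementwise, this is exactly `2 .* a .* x`.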

julia> using Zygote, CUDA

julia> using Zygote: @adjoint, Numeric, broadcasted

julia> x_c = CUDA.rand(ComplexF32, 2, 2)
2×2 CuArray{ComplexF32, 2}:
 0.508057+0.580545im  0.846364+0.0125523im
 0.965085+0.773818im  0.962015+0.289069im

julia> x = Array(x_c);

julia> Zygote.gradient(x -> sum(abs2.(x)), x)
(ComplexF32[1.016113f0 + 1.1610899f0im 1.6927286f0 + 0.025104642f0im; 1.9301703f0 + 1.5476352f0im 1.9240301f0 + 0.5781378f0im],)

julia> Zygote.gradient(x -> sum(abs2.(x)), x_c)  # old error we know
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:664
 [...]

julia> @adjoint broadcasted(::typeof(abs2), x::Numeric) =
         abs2.(x), a -> (nothing, 2 .* a .* x)

julia> Zygote.gradient(x -> sum(abs2.(x)), x_c)
(ComplexF32[1.016113f0 + 1.1610899f0im 1.6927286f0 + 0.025104642f0im; 1.9301703f0 + 1.5476352f0im 1.9240301f0 + 0.5781378f0im],)

roflmaostc commented 3 years ago

Hm, I'm confused why the code distinguishes between complex and real numbers.

DualNumbers.jl doesn't, for example.

Edit: OK, after reading some discussions, that's not exactly the same as in ForwardDiff. I think this is beyond my current understanding of Zygote etc.

mcabbott commented 3 years ago

ForwardDiff's Dual numbers should work with complex numbers, but the way they are produced and consumed would need to change. The functions are a bit sloppy: they add a dual perturbation to real numbers and silently ignore other types. It would be much better to throw an error for plausibly differentiable types that can't be handled.

Anyway, a first look:

using Zygote
gradient(x -> sum(sqrt.(x)), [1,2,3])

y, b = Zygote.broadcast_forward(sqrt, [1,2,3]) # method used for CuArrays
b([1,1,1])  # same as normal

y1, b1 = pullback(x -> abs.(x), [1,2+im,3-im])
y2, b2 = Zygote.broadcast_forward(abs, [1,2+im,3-im])
b1([1,1,1])
b2([1,1,1]) # nothing

@eval Zygote dual(x::Complex, p) = Complex(Dual(real(x), p), imag(x))

b2([1,1,1]) # on re-running, now has the real parts of the sensitivity

So I think you need something like this, with 2N perturbations when complex numbers are present (untested!):

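using ForwardDiff: Dual  # the Dual type Zygote uses for forward-mode broadcasting

# Other arguments pass through unchanged; reals get one perturbation;
# complex numbers get separate perturbations for their real and imaginary parts.
dual(x, p, pc=()) = x
dual(x::Real, p, pc=()) = Dual(x, p)
dual(x::Complex, p, pc) = Complex(Dual(real(x), p), Dual(imag(x), pc))

function dual_function(f::F) where F
  function dual_f(args::Vararg{Any,N}) where N
    if any(a isa Complex for a in args)
      # 2N partials: slot i for the real part of argument i, slot N+i for its imaginary part
      ds = map(args, ntuple(identity, Val(N))) do x, i
        dual(x, ntuple(j -> i==j, Val(2N)), ntuple(j -> N+i==j, Val(2N)))
      end
      return f(ds...)
    else
      # N partials: one per (real) argument, as before
      ds = map(args, ntuple(identity, Val(N))) do x, i
        dual(x, ntuple(j -> i==j, Val(N)))
      end
      return f(ds...)
    end
  end
end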

And then broadcast_forward needs to extract these.
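To illustrate the extraction step with N = 1 complex argument (2N = 2 perturbations), here is a self-contained toy in plain Julia. `MiniDual`, `dualize` and `forward_grad` are hypothetical names; this is neither ForwardDiff's `Dual` nor Zygote's `broadcast_forward`, only a sketch of the idea: seed one partial for the real part and one for the imaginary part, then recombine them as ∂/∂re + im·∂/∂im, which reproduces the gradient `2z` that Zygote reports for `abs2` on the CPU.

```julia
# Hypothetical minimal dual number: a value plus N partial derivatives.
struct MiniDual{N} <: Real
    val::Float64
    partials::NTuple{N,Float64}
end

# Just enough arithmetic for Base's abs2(z::Complex),
# which computes real(z)*real(z) + imag(z)*imag(z).
Base.:+(a::MiniDual{N}, b::MiniDual{N}) where {N} =
    MiniDual(a.val + b.val, a.partials .+ b.partials)
Base.:*(a::MiniDual{N}, b::MiniDual{N}) where {N} =
    MiniDual(a.val * b.val, a.val .* b.partials .+ b.val .* a.partials)

# Two perturbations for one complex input: slot 1 tracks the real part,
# slot 2 the imaginary part (the "2N perturbations" idea with N = 1).
dualize(z::Complex) = Complex(MiniDual(Float64(real(z)), (1.0, 0.0)),
                              MiniDual(Float64(imag(z)), (0.0, 1.0)))

# Recombine the partials of a real-valued result, scaled by a real cotangent Δ.
function forward_grad(f, z::Complex, Δ::Real = 1.0)
    y = f(dualize(z))::MiniDual{2}
    return Δ * complex(y.partials[1], y.partials[2])
end

z = 0.0598111 + 0.678913im
println(forward_grad(abs2, z) == 2z)  # true
```

Seeding only the real part (as in the `@eval Zygote dual(x::Complex, p)` experiment above) would leave the second slot empty, which is why that version recovers only the real parts of the sensitivity.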

That's for the general case, which would also now be useful for broadcasting complex numbers on the CPU.

In the meantime, adding the @adjoint rule you suggest for abs2 would not be crazy; there are already rules for real, conj, etc. here: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L108-L115

roflmaostc commented 2 years ago

Hi, since this is still an issue, I usually work around it by defining custom adjoints. But I'm confused by this one:

Is that rule wrong?

julia> using Zygote, CUDA

julia> using Zygote:@adjoint, broadcasted

(jl_bqEEuk) pkg> status
Status `/tmp/jl_bqEEuk/Project.toml`
  [052768ef] CUDA v3.12.0
  [e88e6eb3] Zygote v0.6.49

julia> @adjoint broadcasted(::typeof(exp), x::CuArray) = 
                   exp.(x), a -> (nothing, exp.(x) .* a)

julia> Zygote.gradient(x -> sum(real.(exp.(x))), Array([1.1im * pi]))
(ComplexF64[-0.9510565162951535 + 0.30901699437494773im],)

julia> Zygote.gradient(x -> sum(real.(exp.(x))), CuArray([1.1im * pi]))
(ComplexF64[-0.9510565162951535 - 0.30901699437494773im],)

roflmaostc commented 2 years ago

So the correct rule apparently is:

julia> @adjoint broadcasted(::typeof(exp), x::CuArray) = 
                   exp.(x), a -> (nothing, exp.(conj.(x)) .* a)

julia> Zygote.gradient(x -> sum(real.(exp.(x))), CuArray([1.1im * pi]))

(ComplexF64[-0.9510565162951535 + 0.30901699437494773im],)
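The `conj` can be checked numerically in plain Julia. Assuming Zygote's convention (gradient = ∂L/∂u + i·∂L/∂v for z = u + iv), a central finite difference of L(z) = real(exp(z)) at z = 1.1πi gives conj(exp(z)) = exp(conj(z)), matching the `Array` result above, while exp(z) alone has the opposite imaginary sign, as produced by the first `CuArray` rule. `fd_grad` is a hypothetical helper.

```julia
# Central finite-difference gradient of a real-valued L at a complex scalar z,
# using the convention ∂L/∂re(z) + im * ∂L/∂im(z).
fd_grad(L, z::Complex; h = 1e-6) =
    complex((L(z + h) - L(z - h)) / (2h), (L(z + im * h) - L(z - im * h)) / (2h))

L(z) = real(exp(z))
z = 1.1im * pi

g = fd_grad(L, z)
println(isapprox(g, exp(conj(z)); rtol = 1e-6))  # true:  ≈ -0.951 + 0.309im
println(isapprox(g, exp(z); rtol = 1e-6))        # false: exp(z) ≈ -0.951 - 0.309im
```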

Where is this rule defined? I couldn't really find it in the jungle of Zygote/ChainRules. I'm still not really familiar with the correct terminology, etc.

mcabbott commented 2 years ago

Yes I think that's right.

There is no such rule in Zygote. In the Array case, I believe it takes the most generic fallback path, which is here: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L197-L206

roflmaostc commented 2 years ago

Thanks a lot for the hint!

And where is the rule for exp defined? I'm still kind of confused about the conj.

mcabbott commented 2 years ago

The rule for exp(x::Complex) will, I think, come from ChainRules.

ChainRules (CR) now has broadcasting rules too, BTW, but Zygote doesn't use them (yet, or ever; not sure).

I'm frequently confused about the conj but there are essentially two conventions for what gradient(real∘f, x+iy) could mean, and Zygote picked one of them.
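For reference, one way to write the two candidate conventions for a real-valued L(z) with z = x + iy, together with what each gives for L = Re f when f is holomorphic (the last step uses the Cauchy-Riemann equations). Zygote's outputs in this thread match the first line; the second is its complex conjugate:

```math
\begin{aligned}
\text{(Zygote)}\quad & \frac{\partial L}{\partial x} + i\,\frac{\partial L}{\partial y}
  &&\Longrightarrow\quad L = \operatorname{Re} f \;\mapsto\; \overline{f'(z)},\\
\text{(conjugate)}\quad & \frac{\partial L}{\partial x} - i\,\frac{\partial L}{\partial y}
  &&\Longrightarrow\quad L = \operatorname{Re} f \;\mapsto\; f'(z).
\end{aligned}
```

For exp, f'(z) = exp(z), so Zygote's convention gives conj(exp(z)) = exp(conj(z)), which is the corrected rule above; the first `CuArray` rule, `exp.(x) .* a`, corresponds to the conjugate convention.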

roflmaostc commented 2 years ago

Yeah, I thought so too, but the only one I found is:

src/rulesets/Base/fastmath_able.jl:        @scalar_rule exp(x) Ω

Does @scalar_rule insert the conj at the right positions?

mcabbott commented 2 years ago

Yes, it should.

roflmaostc commented 2 years ago

Ok, thanks! That's interesting to know :)

CarloLucibello commented 1 year ago

closed #1324