JuliaDiff / ReverseDiff.jl

Reverse Mode Automatic Differentiation for Julia

Handle matrix times matrix = vector case #227

Closed dkarrasch closed 1 year ago

dkarrasch commented 1 year ago

This "fixes" some weird behavior in the multiplication code. It occurred in https://s3.amazonaws.com/julialang-reports/nanosoldier/pkgeval/by_hash/8a3027b_vs_960870e/PositiveFactorizations.primary.log PositiveFactorizations.jl. In their case, they want to multiply a matrix by the transpose of a row matrix (i.e., vector-like) into a vector. "By chance", this works out from the pov of dimensions, but from the pov of types having such a mul! method is weird and we may not wish to continue to support this (https://github.com/JuliaLang/julia/pull/49521#discussion_r1178114739). To make this package (and PositiveFactorizations.jl) run smoothly across (past and upcoming) versions, I proposed to simply catch that case and reshape the output vector to a matrix. In fact, this may even turn out to be advantageous in terms of performance, because that strange method in LinearAlgebra.jl calls generic_matmatmul!, for arguments matching the following signature:

mul!(::Vector{Float64}, ::Matrix{Float64}, ::Transpose{Float64, Matrix{Float64}}, ::Bool, ::Bool)

Once we reshape the vector to a matrix, this would go down the BLAS route!
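
To illustrate the proposed reshape trick, here is a minimal sketch (the helper name `reshaped_mul!` is hypothetical, not the actual patch in this PR):

```julia
using LinearAlgebra

# Sketch of the workaround: when the output buffer is a Vector but both
# operands are matrix-shaped, view the output as an n×1 matrix so that
# mul! dispatches to the BLAS-backed matrix-matrix method instead of
# generic_matmatmul!.
function reshaped_mul!(out::AbstractVector, A::AbstractMatrix, B::AbstractMatrix)
    mul!(reshape(out, length(out), 1), A, B)
    return out
end

A = rand(3, 2)
B = transpose(rand(1, 2))  # transpose of a row matrix: 2×1, vector-like
out = zeros(3)
reshaped_mul!(out, A, B)   # 3×2 times 2×1, stored into a length-3 vector
```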

codecov[bot] commented 1 year ago

Codecov Report

Patch coverage: 100.00% and project coverage change: -3.05% :warning:

Comparison is base (f1f3d1f) 84.51% compared to head (90f676b) 81.46%.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #227      +/-   ##
==========================================
- Coverage   84.51%   81.46%    -3.05%
==========================================
  Files          18       18
  Lines        1924     1581      -343
==========================================
- Hits         1626     1288      -338
+ Misses        298      293        -5
```

| [Impacted Files](https://app.codecov.io/gh/JuliaDiff/ReverseDiff.jl/pull/227?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=JuliaDiff) | Coverage Δ | |
|---|---|---|
| [src/derivatives/linalg/arithmetic.jl](https://app.codecov.io/gh/JuliaDiff/ReverseDiff.jl/pull/227?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=JuliaDiff#diff-c3JjL2Rlcml2YXRpdmVzL2xpbmFsZy9hcml0aG1ldGljLmps) | `72.15% <100.00%> (+1.02%)` | :arrow_up: |

... and [16 files with indirect coverage changes](https://app.codecov.io/gh/JuliaDiff/ReverseDiff.jl/pull/227/indirect-changes?src=pr&el=tree-more&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=JuliaDiff)


dkarrasch commented 1 year ago

Gentle bump.

devmotion commented 1 year ago

I don't quite get why exactly this method is needed. Shouldn't the re-routing happen in LinearAlgebra? Why is some special handling in ReverseDiff needed?

In any case, it seems the code is not covered by tests yet.

dkarrasch commented 1 year ago

So here's the full stacktrace. It shows that the output is allocated by ReverseDiff, and then LinearAlgebra takes what it gets:

MethodError: no method matching mul!(::Vector{Float64}, ::Matrix{Float64}, ::Transpose{Float64, Matrix{Float64}}, ::Bool, ::Bool)

  Closest candidates are:
    mul!(::StridedVector{T}, ::StridedVecOrMat{T}, !Matched::StridedVector{T}, ::Number, ::Number) where T<:Union{Float32, Float64, ComplexF64, ComplexF32}
     @ LinearAlgebra /opt/julia/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:66
    mul!(!Matched::StridedMatrix{T}, ::StridedVecOrMat{T}, ::Transpose{<:Any, <:StridedVecOrMat{T}}, ::Number, ::Number) where T<:Union{Float32, Float64, ComplexF64, ComplexF32}
     @ LinearAlgebra /opt/julia/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:369
    mul!(!Matched::StridedMatrix{T}, ::Union{Adjoint{<:Any, <:StridedVecOrMat{T}}, Transpose{<:Any, <:StridedVecOrMat{T}}, StridedMatrix{T}, StridedVector{T}}, ::Union{Adjoint{<:Any, <:StridedVecOrMat{T}}, Transpose{<:Any, <:StridedVecOrMat{T}}, StridedMatrix{T}, StridedVector{T}}, ::Number, ::Number) where T<:Union{Float32, Float64, ComplexF64, ComplexF32}
     @ LinearAlgebra /opt/julia/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:356
    ...

  Stacktrace:
    [1] mul!
      @ /opt/julia/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:251 [inlined]
    [2] reverse_mul!(output::ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, output_deriv::Matrix{Float64}, a::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, b::Matrix{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, a_tmp::Vector{Float64}, b_tmp::Matrix{Float64})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/derivatives/linalg/arithmetic.jl:273
    [3] special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(*), Tuple{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, Matrix{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, Tuple{Vector{Float64}, Matrix{Float64}}})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/derivatives/linalg/arithmetic.jl:265
    [4] reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(*), Tuple{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, Matrix{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, Tuple{Vector{Float64}, Matrix{Float64}}})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/tape.jl:93
    [5] reverse_pass!(tape::Vector{ReverseDiff.AbstractInstruction})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/tape.jl:87
    [6] reverse_pass!
      @ ~/.julia/packages/ReverseDiff/Zu4v6/src/api/tape.jl:36 [inlined]
    [7] seeded_reverse_pass!(result::Vector{Float64}, output::ReverseDiff.TrackedReal{Float64, Float64, Nothing}, input::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, tape::ReverseDiff.GradientTape{var"#8#15"{DataType, var"#5#12"}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/api/utils.jl:31
    [8] seeded_reverse_pass!(result::Vector{Float64}, t::ReverseDiff.GradientTape{var"#8#15"{DataType, var"#5#12"}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/api/tape.jl:47
    [9] gradient(f::Function, input::Vector{Float64}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/api/gradients.jl:24
   [10] gradient(f::Function, input::Vector{Float64})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/api/gradients.jl:22

I'm not familiar enough with AD to set things up so that it goes down this route. As I wrote, it came up with PositiveFactorizations.jl when AD-ing cholesky or something like that, so it's not easily reproducible here.

> In any case, it seems the code is not covered by tests yet.

True, there's quite a lot uncovered in that code area. Any help with that would be much appreciated.

dkarrasch commented 1 year ago

If you wish, there is a test in PositiveFactorizations.jl ("downstream") that exercises this. Not satisfactory, but given that reverse_mul! is already largely uncovered, maybe acceptable?

EDIT: I think I found a test case.

dkarrasch commented 1 year ago

I think it no longer does after I merged https://github.com/JuliaLang/julia/pull/49521.

devmotion commented 1 year ago

It seems that the PR introduces method ambiguity issues, which cause test failures after removing `collect` in the tests.

Can you fix these and also add a version of `norm_hermitian` with `collect` (but keep the one without as well, since it seems to cover some of the method ambiguity issues)?

If I can find some time, I'll also check whether the initial problem can be fixed in some other way, since evidently adding new dispatches is a bit problematic.

dkarrasch commented 1 year ago

Interestingly, the ambiguities do not occur on v1.8, only on the very old versions. I have now avoided adding another method and instead introduced branches in the existing one. Those branches, however, should be compiled away, because the conditions can be checked in the type domain.
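
For illustration, here's a hedged sketch of such a type-domain branch (hypothetical code, not the actual diff): since the `isa` condition depends only on the argument types, the compiler can constant-fold it for concrete inputs and eliminate the dead branch.

```julia
using LinearAlgebra

# Hypothetical sketch: the condition is decidable from the argument types
# alone, so for concrete inputs only one branch survives compilation.
function branched_mul!(out, A, B)
    if out isa AbstractVector && A isa AbstractMatrix && B isa AbstractMatrix
        # matrix * matrix written into a vector: treat the output as n×1
        mul!(reshape(out, length(out), 1), A, B)
    else
        mul!(out, A, B)
    end
    return out
end

# `@code_typed branched_mul!(zeros(3), rand(3, 2), transpose(rand(1, 2)))`
# should show only the reshaping branch remaining.
```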

BTW, this package has more than 3k ambiguities! My computer almost crashed when running Aqua.jl on it.
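
(For reference, the check that surfaces these is Aqua.jl's ambiguity test; the snippet below is just the standard invocation, nothing PR-specific:)

```julia
using Aqua, ReverseDiff

# Aqua.jl's ambiguity check wraps Test.detect_ambiguities; on a package
# with thousands of ambiguous method pairs this is slow and memory-hungry.
Aqua.test_ambiguities(ReverseDiff)
```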

devmotion commented 1 year ago

Thank you @dkarrasch!