Jutho / StridedViews.jl

A Julia package to represent strided views over a parent DenseArray
MIT License

Remove `DenseArray` restriction and add support for `FillArrays` #6

Open lkdvos opened 10 months ago

lkdvos commented 10 months ago

This removes the hard requirement that the parent array subtype `DenseArray`, but only defines a default constructor for `StridedView(::DenseArray)`. (Note that the current implementation does, however, have a constructor `StridedView(::AbstractArray, size, strides, offset, op)`, which is useful because it avoids having to redefine methods like `conj` etc. for non-`DenseArray` StridedViews.)
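The point of a constructor that takes `(parent, size, strides, offset, op)` is that indexing and lazy operations only need those fields, not a `DenseArray` parent. Below is a minimal sketch of that idea, assuming a hypothetical `ToyStridedView` type (it is not the StridedViews.jl implementation): elements are read through `getindex` on the parent, so a range, which is not a `DenseArray`, works as a parent, and `conj` is written once by composing it into `op`.

```julia
# Hypothetical sketch, not the StridedViews.jl source: a view described by
# (parent, size, strides, offset, op), where the parent may be any AbstractVector.
struct ToyStridedView{T,N,A<:AbstractVector,F} <: AbstractArray{T,N}
    parent::A
    size::NTuple{N,Int}
    strides::NTuple{N,Int}
    offset::Int
    op::F   # applied elementwise on read, e.g. identity or conj
end

# Element type is taken from the parent (sufficient for identity/conj).
function ToyStridedView(parent::AbstractVector, size::NTuple{N,Int}, strides::NTuple{N,Int},
                        offset::Int, op=identity) where {N}
    return ToyStridedView{eltype(parent),N,typeof(parent),typeof(op)}(parent, size, strides, offset, op)
end

Base.size(V::ToyStridedView) = V.size

function Base.getindex(V::ToyStridedView{T,N}, I::Vararg{Int,N}) where {T,N}
    checkbounds(V, I...)
    # 1-based position in the parent: offset + 1 + sum over k of (I[k] - 1) * strides[k]
    pos = V.offset + 1
    for k in 1:N
        pos += (I[k] - 1) * V.strides[k]
    end
    return V.op(V.parent[pos])
end

# conj only touches `op`, so it works for every parent type, dense or not.
# The Real method also resolves the ambiguity with Base's conj(::AbstractArray{<:Real}).
Base.conj(V::ToyStridedView{<:Real}) = V
Base.conj(V::ToyStridedView) =
    ToyStridedView(V.parent, V.size, V.strides, V.offset, conj ∘ V.op)
```

For example, `ToyStridedView(1.0:12.0, (3, 4), (4, 1), 0)` reads the column-major 4×3 layout of `1.0:12.0` as its transpose, without copying anything.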

Additionally, it defines a package extension to handle FillArrays, which sometimes show up in, for example, Zygote's automatic differentiation rules. (I currently implemented this solely as a package extension, and not through Requires.jl for Julia < v1.9.)
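One possible way to fit a `Fill` into the (size, strides, offset) model is a zero stride in every dimension, so that every index maps to the same stored element. This is a hypothetical sketch; it is not taken from `ext/StridedViewsFillArraysExt.jl`, and `parent_positions` and `zero_stride_fill` are illustrative names.

```julia
# Hypothetical sketch, not the code in ext/StridedViewsFillArraysExt.jl:
# a constant array needs only one stored value if every stride is zero.

# 1-based parent positions visited by a view with the given size/strides/offset (N >= 1)
parent_positions(sz::NTuple{N,Int}, st::NTuple{N,Int}, offset::Int=0) where {N} =
    [offset + 1 + sum((Tuple(I) .- 1) .* st) for I in CartesianIndices(sz)]

# Materialize a "fill" from a 1-element buffer using zero strides
function zero_stride_fill(value, sz::NTuple{N,Int}) where {N}
    buffer = [value]                               # the single stored element
    pos = parent_positions(sz, ntuple(_ -> 0, N))  # every entry of pos is 1
    return buffer[pos]
end
```

Even with this representation, such an operand must not be handed to BLAS: routines like `gemm` require the leading dimension to be at least `max(1, m)`, which a zero stride violates.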

This fixes #2

codecov[bot] commented 10 months ago

Codecov Report

Attention: 5 lines in your changes are missing coverage. Please review.

Comparison is base (517e676) 93.67% compared to head (5ba0421) 91.53%.

| Files | Patch % | Lines |
|---|---|---|
| ext/StridedViewsFillArraysExt.jl | 64.28% | 5 Missing :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main       #6      +/-   ##
==========================================
- Coverage   93.67%   91.53%   -2.15%
==========================================
  Files           4        5       +1
  Lines         174      189      +15
==========================================
+ Hits          163      173      +10
- Misses         11       16       +5
```


kishore-nori commented 9 months ago

Hi @lkdvos, thank you very much for this PR. As you mentioned, I came across the missing `StridedView` constructor for FillArrays in an AD setting (see the MWE below). I tried this branch to check whether it resolves the example below, but deep down it throws another error: there is no `pointer` conversion for a `Fill` (in only one of the two cases below), and it gets called within `unsafe_convert` on a `StridedView`. Below are the example and the stack trace when using this branch:

using TensorOperations, StridedViews, Zygote

A = rand(4,3,2)
x = rand(2)

function f(x)
    TensorOperations.@tensor B[a,b] := A[a,b,c] * x[c]
    sum(B)
end

function g(x)
    TensorOperations.@tensor B[a,b] := A[a,b,c] * x[c]
    sum(B, dims=2)
end

Zygote.gradient(f, x) # doesn't work (the stack trace below is from Zygote.jacobian(f, x))

Zygote.jacobian(g,x) # works and matches with ForwardDiff
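For a regression test, the expected gradient of `f` can be written down by hand: with B[a,b] = Σ_c A[a,b,c] x[c], one gets ∂ sum(B)/∂x[c] = Σ_{a,b} A[a,b,c]. In reverse mode, the cotangent of `B` is an all-ones matrix the size of `B`; in the trace below it arrives as the 2-dimensional `FillArrays.Fill`. The plain-Julia sketch below needs no TensorOperations or Zygote, and its function names are hypothetical. It computes this gradient three ways so an AD result can be checked against it.

```julia
# Plain-Julia reference for the MWE (no TensorOperations or Zygote needed).
# B[a,b] = Σ_c A[a,b,c] * x[c]
contract3(A, x) = [sum(A[a, b, c] * x[c] for c in axes(A, 3)) for a in axes(A, 1), b in axes(A, 2)]
f_ref(A, x) = sum(contract3(A, x))

# Analytic gradient: ∂f/∂x[c] = Σ_{a,b} A[a,b,c]
grad_analytic(A) = vec(sum(A; dims=(1, 2)))

# Same gradient as a pullback: contract the cotangent of B (all ones, the role
# played by the Fill in the trace) with A over a and b.
function grad_pullback(A)
    ΔB = ones(eltype(A), size(A, 1), size(A, 2))   # dense stand-in for the Fill
    return [sum(ΔB[a, b] * A[a, b, c] for a in axes(A, 1), b in axes(A, 2)) for c in axes(A, 3)]
end

# Central finite differences as an independent check
function grad_fd(fun, x; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x); e[i] = h
        g[i] = (fun(x .+ e) - fun(x .- e)) / (2h)
    end
    return g
end
```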

# The stack trace of the error follows:
julia> Zygote.jacobian(f,x)
ERROR: conversion to pointer not defined for FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] unsafe_convert(#unused#::Type{Ptr{Float64}}, a::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Base ./pointer.jl:67
  [3] pointer
    @ ./abstractarray.jl:1245 [inlined]
  [4] unsafe_convert(#unused#::Type{Ptr{Float64}}, a::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)})
    @ StridedViews ~/.julia/packages/StridedViews/cN0vi/src/stridedview.jl:191
  [5] gemm!(transA::Char, transB::Char, alpha::Float64, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, B::StridedView{Float64, 2, Vector{Float64}, typeof(identity)}, beta::Float64, C::StridedView{Float64, 2, Array{Float64, 3}, typeof(identity)})
    @ LinearAlgebra.BLAS ~/julia-src/julia-1.9.3/share/julia/stdlib/v1.9/LinearAlgebra/src/blas.jl:1524
  [6] _threaded_blas_mul!(C::StridedView{Float64, 2, Array{Float64, 3}, typeof(identity)}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, B::StridedView{Float64, 2, Vector{Float64}, typeof(identity)}, α::VectorInterface.One, β::VectorInterface.Zero, nthreads::Int64)
    @ Strided ~/.julia/packages/Strided/l1vm3/src/linalg.jl:105
  [7] _mul!
    @ ~/.julia/packages/Strided/l1vm3/src/linalg.jl:91 [inlined]
  [8] mul!(C::StridedView{Float64, 2, Array{Float64, 3}, typeof(identity)}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, B::StridedView{Float64, 2, Vector{Float64}, typeof(identity)}, α::VectorInterface.One, β::VectorInterface.Zero)
    @ Strided ~/.julia/packages/Strided/l1vm3/src/linalg.jl:60
  [9] _unsafe_blas_contract!(C::StridedView{Float64, 3, Array{Float64, 3}, typeof(identity)}, ipC::Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::StridedView{Float64, 1, Vector{Float64}, typeof(identity)}, pB::Tuple{Tuple{}, Tuple{Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero)
    @ TensorOperations ~/.julia/packages/TensorOperations/LAzcX/src/implementation/strided.jl:163
 [10] blas_contract!(C::StridedView{Float64, 3, Array{Float64, 3}, typeof(identity)}, pC::Tuple{Tuple{Int64, Int64, Int64}, Tuple{}}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::StridedView{Float64, 1, Vector{Float64}, typeof(identity)}, pB::Tuple{Tuple{}, Tuple{Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero)
    @ TensorOperations ~/.julia/packages/TensorOperations/LAzcX/src/implementation/strided.jl:137
 [11] tensorcontract!(C::StridedView{Float64, 3, Array{Float64, 3}, typeof(identity)}, pC::Tuple{Tuple{Int64, Int64, Int64}, Tuple{}}, A::StridedView{Float64, 2, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, typeof(identity)}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::StridedView{Float64, 1, Vector{Float64}, typeof(identity)}, pB::Tuple{Tuple{}, Tuple{Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero, backend::TensorOperations.Backend{:StridedBLAS})
    @ TensorOperations ~/.julia/packages/TensorOperations/LAzcX/src/implementation/strided.jl:65
 [12] tensorcontract!
    @ ~/.julia/packages/TensorOperations/LAzcX/src/implementation/abstractarray.jl:63 [inlined]
 [13] tensorcontract!
    @ ~/.julia/packages/TensorOperations/LAzcX/src/implementation/abstractarray.jl:35 [inlined]
 [14] #62
    @ ~/.julia/packages/TensorOperations/LAzcX/ext/TensorOperationsChainRulesCoreExt.jl:99 [inlined]
 [15] unthunk
    @ ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:204 [inlined]
 [16] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:110 [inlined]
 [17] map (repeats 4 times)
    @ ./tuple.jl:276 [inlined]
 [18] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:111 [inlined]
 [19] ZBack
    @ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:211 [inlined]
 [20] Pullback
    @ ./REPL[54]:2 [inlined]
 [21] (::Zygote.Pullback{Tuple{typeof(f), Vector{Float64}}, Tuple{Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(scalartype), Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(TensorOperations.promote_contract), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(TensorOperations.tensorop), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(TensorOperations.tensorop), Type{Tuple{Float64, Float64}}}, Tuple{typeof(Core.Compiler.return_type)}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(scalartype), Array{Float64, 3}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Array{Float64, 3}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#3027#back#782"{Zygote.var"#776#780"{Matrix{Float64}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensoralloc_contract_pullback#41"{Tuple{DataType, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, 
Bool}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#67"{Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, VectorInterface.One, VectorInterface.Zero, Tuple{}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [22] #291
    @ ~/.julia/packages/Zygote/XJ8pP/src/lib/lib.jl:206 [inlined]
 [23] (::Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}}, Zygote.Pullback{Tuple{typeof(f), Vector{Float64}}, Tuple{Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(scalartype), Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(TensorOperations.promote_contract), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.Pullback{Tuple{typeof(Base.promote_op), typeof(TensorOperations.tensorop), Type{Float64}, Type{Float64}}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}}, Zygote.ZBack{ChainRules.var"#apply_type_pullback#42"{Tuple{DataType, DataType}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Core.Compiler.return_type), typeof(TensorOperations.tensorop), Type{Tuple{Float64, Float64}}}, Tuple{typeof(Core.Compiler.return_type)}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(scalartype), Array{Float64, 3}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}, Zygote.Pullback{Tuple{typeof(scalartype), Type{Array{Float64, 3}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}}}, Zygote.var"#3027#back#782"{Zygote.var"#776#780"{Matrix{Float64}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensoralloc_contract_pullback#41"{Tuple{DataType, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, 
Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, Bool}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#67"{Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Array{Float64, 3}, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}, Symbol, Vector{Float64}, Tuple{Tuple{Int64}, Tuple{}}, Symbol, VectorInterface.One, VectorInterface.Zero, Tuple{}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
 [24] Pullback
    @ ./operators.jl:1035 [inlined]
 [25] (::Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [26] Pullback
    @ ./operators.jl:1034 [inlined]
 [27] Pullback
    @ ./operators.jl:1031 [inlined]
 [28] (::Zygote.Pullback{Tuple{Base.var"##_#97", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), typeof(f)}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(Zygote._jvec)}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(f)}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(f)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(f)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any}, Zygote.var"#2145#back#281"{Zygote.var"#277#280"}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Float64}, Tuple{Zygote.ZBack{ChainRules.var"#vcat_pullback#1416"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}}, Val{1}}}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Vector{Float64}}, Tuple{}}}}}}}}}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [29] #291
    @ ~/.julia/packages/Zygote/XJ8pP/src/lib/lib.jl:206 [inlined]
 [30] #2173#back
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
 [31] Pullback
    @ ./operators.jl:1031 [inlined]
 [32] (::Zygote.Pullback{Tuple{ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.var"#2370#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{Base.var"##_#97", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), typeof(f)}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(Zygote._jvec)}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(f)}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(f)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(f)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any}, Zygote.var"#2145#back#281"{Zygote.var"#277#280"}, 
Zygote.Pullback{Tuple{typeof(Zygote._jvec), Float64}, Tuple{Zygote.ZBack{ChainRules.var"#vcat_pullback#1416"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}}, Val{1}}}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Vector{Float64}}, Tuple{}}}}}}}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [33] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.var"#2370#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{Base.var"##_#97", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), ComposedFunction{typeof(Zygote._jvec), typeof(f)}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:outer, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(Zygote._jvec)}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:inner, Zygote.Context{false}, ComposedFunction{typeof(Zygote._jvec), typeof(f)}, typeof(f)}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(f)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(f)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.Pullback{Tuple{typeof(Base.unwrap_composed), typeof(Zygote._jvec)}, Tuple{Zygote.Pullback{Tuple{typeof(Base.maybeconstructor), typeof(Zygote._jvec)}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2017#back#204"{typeof(identity)}}}}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(Zygote._jvec), typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2033#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, typeof(Zygote._jvec)}}, Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(f)}, Tuple{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Any}, 
Zygote.var"#2145#back#281"{Zygote.var"#277#280"}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Float64}, Tuple{Zygote.ZBack{ChainRules.var"#vcat_pullback#1416"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}}, Val{1}}}, Zygote.Pullback{Tuple{typeof(Zygote._jvec), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(vec), Vector{Float64}}, Tuple{}}}}}}}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface.jl:45
 [34] withjacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/lib/grad.jl:150
 [35] jacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/lib/grad.jl:128
 [36] top-level scope
    @ REPL[63]:1

I thought it would be relevant to this PR, so I'm posting it here; if it requires a new issue, let me know. Thank you.
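Frames [2] to [4] show the root cause: `unsafe_convert` on the `StridedView` asks its `Fill` parent for a `pointer`, and an array with no memory buffer behind it cannot provide one. Below is a minimal reproduction without TensorOperations or Zygote, using a hypothetical `ConstMatrix` as a stand-in for `FillArrays.Fill`.

```julia
# Hypothetical stand-in for FillArrays.Fill: a constant matrix with no
# memory buffer behind it (an AbstractMatrix, but not a DenseArray).
struct ConstMatrix{T} <: AbstractMatrix{T}
    value::T
    dims::Tuple{Int,Int}
end
Base.size(A::ConstMatrix) = A.dims
Base.getindex(A::ConstMatrix, i::Int, j::Int) = (checkbounds(A, i, j); A.value)

# BLAS wrappers need a raw pointer to the data. For a memory-backed array this
# succeeds; for ConstMatrix the generic AbstractArray fallback of `pointer` throws
# (on Julia 1.9 with the "conversion to pointer not defined" message seen above).
function has_pointer(A::AbstractArray)
    try
        pointer(A)
        return true
    catch
        return false
    end
end
```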

lkdvos commented 9 months ago

Thanks for reporting this! I think my PR indeed needs a bit more work: multiplication involving a `Fill`-backed StridedView should not be dispatched to BLAS like that, and your example is exactly the reason I started looking into this. I'll try to make some time to look further into it next week, and I hope to add this as a test case to the TensorOperations test suite.
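Below is a minimal sketch of the kind of guard this implies, assuming plain Julia matrices rather than `StridedView`. It checks that every operand is a unit-stride `StridedMatrix` of a BLAS element type before calling `BLAS.gemm!`, and otherwise falls back to a generic loop. `isblasmatrix`, `naive_mul!` and `guarded_mul!` are hypothetical names, not functions from Strided.jl or TensorOperations.jl.

```julia
using LinearAlgebra

# Hypothetical check: BLAS needs a strided, column-major operand with a
# BLAS element type and a leading dimension of at least max(1, rows).
isblasmatrix(A) = A isa StridedMatrix{<:LinearAlgebra.BlasFloat} &&
    stride(A, 1) == 1 && stride(A, 2) >= max(1, size(A, 1))

# Generic fallback that only uses getindex, so it works for any AbstractMatrix.
function naive_mul!(C, A, B)
    size(A, 2) == size(B, 1) && size(C) == (size(A, 1), size(B, 2)) ||
        throw(DimensionMismatch("incompatible matrix sizes"))
    fill!(C, zero(eltype(C)))
    for j in axes(B, 2), p in axes(A, 2)
        b = B[p, j]
        for i in axes(A, 1)
            C[i, j] += A[i, p] * b
        end
    end
    return C
end

# Use BLAS only when every operand qualifies and the element types agree.
function guarded_mul!(C, A, B)
    T = eltype(C)
    if isblasmatrix(C) && isblasmatrix(A) && isblasmatrix(B) &&
       eltype(A) == T && eltype(B) == T
        return LinearAlgebra.BLAS.gemm!('N', 'N', one(T), A, B, zero(T), C)
    else
        return naive_mul!(C, A, B)
    end
end
```

In the usage below, a reshaped range plays the role of the `Fill`: neither is a `DenseArray`, so the generic path is taken instead of BLAS.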

kishore-nori commented 9 months ago

Great, thank you very much, that would be very helpful.