invenia / Nabla.jl

An operator-overloading, tape-based, reverse-mode AD

Custom sensitivities for strided matmul never hit and I think are wrong #192

Open oxinabox opened 4 years ago

oxinabox commented 4 years ago

This file: https://github.com/invenia/Nabla.jl/blob/4cadc87677fb1187354999dcf93bd528f45f85d0/src/sensitivities/linalg/strided.jl

currently defines:

const RS = StridedMatrix{<:∇Scalar}
const RST = Transpose{<:∇Scalar, RS}
const RSA = Adjoint{<:∇Scalar, RS}

But it should say:

const RS = StridedMatrix{<:∇Scalar}
const RST = Transpose{<:∇Scalar, <:RS}
const RSA = Adjoint{<:∇Scalar, <:RS}

Otherwise RST targets exactly Transpose{<:∇Scalar, Union{DenseArray, ...}} (and similarly for RSA). Because type parameters are invariant, that type never occurs in real code unless the Transpose is constructed manually with the Union as its parent type.
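A minimal sketch of the invariance problem, for anyone who wants to check it in a REPL. It assumes `Real` as a stand-in for `∇Scalar` (the name `Scalar` and the `_invariant` / `_covariant` aliases are hypothetical, not from Nabla), and only tests subtyping rather than Nabla's actual dispatch:

```julia
using LinearAlgebra

# Hypothetical stand-in for Nabla's ∇Scalar; an assumption for illustration only.
const Scalar = Real

const RS = StridedMatrix{<:Scalar}

# As currently written: the parent type must be *exactly* RS.
const RST_invariant = Transpose{<:Scalar, RS}
# As proposed: the parent type may be any subtype of RS.
const RST_covariant = Transpose{<:Scalar, <:RS}

A = rand(3, 3)
At = transpose(A)  # a Transpose{Float64, Matrix{Float64}}

# Matrix{Float64} is a subtype of RS but is not equal to RS,
# so only the `<:RS` alias matches an ordinary transposed matrix.
matches_invariant = typeof(At) <: RST_invariant
matches_covariant = typeof(At) <: RST_covariant

println("invariant alias matches: ", matches_invariant)
println("covariant alias matches: ", matches_covariant)
```

A method whose argument is annotated with the first alias would therefore not be selected for `transpose(A)`, which is consistent with those rules never being hit.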

So I think the only strided rule that is ever hit is: (RS, RS, 'N', 'C', :Ȳ, :B, 'C', 'N', :A, :Ȳ)
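For reference, a sketch of what that one rule appears to compute. This assumes the tuple reads as "sensitivity of A is gemm('N', 'C', Ȳ, B), sensitivity of B is gemm('C', 'N', A, Ȳ)", which is my reading of the table rather than something confirmed here; the variable names below are illustrative:

```julia
using LinearAlgebra
using LinearAlgebra.BLAS: gemm

A = rand(3, 4)
B = rand(4, 2)
Y = A * B
Ȳ = ones(size(Y))  # incoming sensitivity for Y, chosen arbitrarily

# Assumed reading of (RS, RS, 'N', 'C', :Ȳ, :B, 'C', 'N', :A, :Ȳ):
Ā = gemm('N', 'C', Ȳ, B)  # Ȳ * B', same shape as A
B̄ = gemm('C', 'N', A, Ȳ)  # A' * Ȳ, same shape as B

println(size(Ā) == size(A), " ", size(B̄) == size(B))
```

These are the standard reverse-mode sensitivities for `Y = A * B` with plain (non-wrapped) strided arguments, so this rule at least matches the textbook result.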

I also think the other rules are wrong, because when I change the aliases to the corrected definitions I get errors saying GEMM is being called incorrectly.

oxinabox commented 3 years ago

Will be closed by #189.