Closed CheukHinHoJerry closed 1 year ago
thanks for spotting this. It is odd. BLAS has operations for transposed matrices - so why aren't they used?
Do matrix operations on PtrArrays not use BLAS, but one of the new pure Julia implementations?
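One way to look into the dispatch question is to ask Julia which method a given mul! call resolves to. The sketch below is only a rough illustration: it uses plain Arrays as a stand-in (an assumption, since the thread is about PtrArrays), and the variable names are made up for the example. The method it reports depends on the Julia version, and the outer method may forward to another routine internally, so treat the result as a starting point rather than a definitive answer.

```julia
using LinearAlgebra
using InteractiveUtils  # provides @which outside the REPL

# Hypothetical stand-in data: plain Arrays, not PtrArrays.
X = randn(10, 4)      # (batch, in_d)
W = randn(3, 4)       # (out_d, in_d)
out = zeros(10, 3)    # (batch, out_d)

# Ask which 5-argument mul! method is selected for a transposed right operand.
# The 3-argument form typically forwards here with alpha = true, beta = false.
m = @which mul!(out, X, transpose(W), true, false)
println(m)
```

Replacing the arguments with the actual PtrArray-wrapped objects would show whether a different method is picked in that case.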
My inclination is to try this first:
mul!(unwrap(out), unwrap(X), transpose(PtrArray(ps.W)))
Some quick tests:
using Polynomials4ML
using LuxCore
using StrideArrays
using ObjectPools
using BenchmarkTools
using Random
P4ML = Polynomials4ML
in_d, out_d = 4, 3 # feature dimensions
N = 10 # batch size
# set up
l = P4ML.LinearLayer(in_d, out_d; feature_first = false)
ps, st = LuxCore.setup(MersenneTwister(1234), l)
# assume the input is a PtrArray
X = randn(N, in_d)
X = PtrArray(X)
# acquire for inplace mul!
out = acquire!(st.pool, :bA, (size(X, 1), l.out_dim), eltype(X));
# current
@btime mul!(unwrap($out), unwrap($X), transpose($ps.W))
# new
@btime mul!(unwrap($out), unwrap($X), transpose(PtrArray($ps.W)))
@btime mul!(transpose(unwrap($out)), $ps.W, transpose(unwrap($X)))
533.353 ns (2 allocations: 20.50 KiB)
26.916 ns (0 allocations: 0 bytes)
28.031 ns (0 allocations: 0 bytes)
with larger matrix size:
694.274 μs (3 allocations: 30.75 KiB)
34.621 μs (0 allocations: 0 bytes)
34.642 μs (0 allocations: 0 bytes)
Not sure where the extra allocation comes from, but that's not important. I guess I am inclined to go with
mul!(unwrap(out), unwrap(X), transpose(PtrArray(ps.W)))
too. I will create a PR to fix it later.
thanks for the careful testing.
closed by #63 ?
Currently this is used in LinearLayer: @which mul! seems to be calling generic_matmul because transpose(ps.W) is a transpose, which is slow. (Actually, in the previous benchmark I didn't observe this for some reason, or I just missed it?)
So there are two ways that I can think of:
Wrap the weight matrix ps.W as a PtrArray, so that transpose returns a PtrArray, and then make sure we are calling matmul! from BLAS.
Or hope that the input X is a CachedArray/TmpArray/PtrArray, and then write to the transpose of the PtrArray of out so that out returns the correct thing:
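Both options rest on the identity (X * Wᵀ)ᵀ = W * Xᵀ, so writing W * Xᵀ into the transpose of out fills out with X * Wᵀ. A minimal sketch of that equivalence, assuming plain Arrays in place of PtrArrays and made-up sizes matching the quick tests in this thread (it checks correctness only, not speed):

```julia
using LinearAlgebra

N, in_d, out_d = 10, 4, 3
X = randn(N, in_d)        # batch of inputs, one sample per row
W = randn(out_d, in_d)    # weights, stored as (out_d, in_d)

# First form: write X * Wᵀ directly into out1.
out1 = zeros(N, out_d)
mul!(out1, X, transpose(W))

# Second form: write W * Xᵀ into the transposed view of out2,
# which fills out2 itself with X * Wᵀ.
out2 = zeros(N, out_d)
mul!(transpose(out2), W, transpose(X))

# The two results should agree up to floating-point rounding.
println(out1 ≈ out2)
```

Which form is faster depends on the array types involved, which is what the benchmarks in this thread measure.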