ACEsuit / Polynomials4ML.jl

Polynomials for ML: fast evaluation, batching, differentiation
MIT License

`mul!` in LinearLayer is calling a "slow" function #62

Closed CheukHinHoJerry closed 1 year ago

CheukHinHoJerry commented 1 year ago

Currently this is used in LinearLayer:

mul!(unwrap(out), unwrap(x), transpose(ps.W)); 

Here mul! seems to be calling generic_matmul because transpose(ps.W) is a Transpose wrapper, and that generic path is slow. (In the previous benchmark I didn't observe this for some reason, or I just missed it.)
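One way to check which method a given mul! call enters is `@which` from InteractiveUtils. The sketch below uses plain `Array`s rather than the PtrArrays from this issue, so the method it prints is not necessarily the one hit in LinearLayer; the variable names and sizes are illustrative assumptions. It only shows the entry method, so following the call further down would need something like `@code_typed` or a profiler.

```julia
using LinearAlgebra
using InteractiveUtils   # provides @which outside the REPL

# Illustrative shapes only: X is (batch, in_d), W is (out_d, in_d).
X = randn(10, 4)
W = randn(3, 4)
out = zeros(10, 3)

# Ask Julia which mul! method this call dispatches to.
m = @which mul!(out, X, transpose(W))
println(m)
```

Running the same query with the actual `unwrap(out)`, `unwrap(x)` and `transpose(ps.W)` arguments should show whether the call takes a generic fallback.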

So there are two ways that I can think of:

Wrap the weight matrix ps.W as a PtrArray, so that transpose also returns a PtrArray and we make sure we are calling matmul! from BLAS:

mul!(unwrap(out), unwrap(X), transpose(PtrArray(ps.W)))

Or we rely on the input X being a CachedArray/TmpArray/PtrArray and write to the transpose of the PtrArray of out, so that out ends up holding the correct result:

mul!(transpose(unwrap(out)), ps.W, transpose(unwrap(X)))
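
The second form relies on the identity (X * Wᵀ)ᵀ = W * Xᵀ: writing W * Xᵀ into the transpose of out fills out with X * Wᵀ. A minimal sketch of that equivalence, assuming plain `Array`s instead of PtrArrays (so it says nothing about speed) and illustrative sizes matching the shapes used in this issue:

```julia
using LinearAlgebra

N, in_d, out_d = 10, 4, 3
X = randn(N, in_d)        # batch of inputs, one sample per row
W = randn(out_d, in_d)    # weights, with the same layout as ps.W

out1 = zeros(N, out_d)
out2 = zeros(N, out_d)

mul!(out1, X, transpose(W))              # out1 = X * Wᵀ
mul!(transpose(out2), W, transpose(X))   # writes W * Xᵀ into out2ᵀ

# Both buffers should now hold the same values up to rounding.
@assert out1 ≈ out2
```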

cortner commented 1 year ago

Thanks for spotting this. It is odd. BLAS has operations for transposed matrices - so why aren't they used?

Do matrix operations on PtrArrays not use BLAS, but one of the new pure Julia implementations?

My inclination is to try this first:

mul!(unwrap(out), unwrap(X), transpose(PtrArray(ps.W)))

CheukHinHoJerry commented 1 year ago

Some quick tests:

using Polynomials4ML
using LuxCore
using StrideArrays
using ObjectPools
using BenchmarkTools
using LinearAlgebra # for mul!
using Random

P4ML = Polynomials4ML

in_d, out_d = 4, 3 # feature dimensions
N = 10 # batch size

# set up
l = P4ML.LinearLayer(in_d, out_d; feature_first = false)
ps, st = LuxCore.setup(MersenneTwister(1234), l)

# assume the input is a PtrArray
X = randn(N, in_d)
X = PtrArray(X)

# acquire for inplace mul!
out = acquire!(st.pool, :bA, (size(X, 1), l.out_dim), eltype(X)); 

# current
@btime mul!(unwrap($out), unwrap($X), transpose($ps.W))

# new
@btime mul!(unwrap($out), unwrap($X), transpose(PtrArray($ps.W)))
@btime mul!(transpose(unwrap($out)), $ps.W, transpose(unwrap($X)))

Output, in the same order as the three @btime calls:

  533.353 ns (2 allocations: 20.50 KiB)
  26.916 ns (0 allocations: 0 bytes)
  28.031 ns (0 allocations: 0 bytes)

With a larger matrix size:

  694.274 μs (3 allocations: 30.75 KiB)
  34.621 μs (0 allocations: 0 bytes)
  34.642 μs (0 allocations: 0 bytes)

I'm not sure where the extra allocation comes from, but that's not important. I guess I'm inclined towards

mul!(unwrap(out), unwrap(X), transpose(PtrArray(ps.W)))

too. I will create a PR to fix it later.
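
For tracking down where allocations like the ones reported by @btime come from, a first step is to measure a single call with `@allocated` from Base. The sketch below is an assumption-laden illustration: `mul_allocs!` is a hypothetical helper, and it uses plain `Array`s, so the byte count it prints need not match the PtrArray numbers above.

```julia
using LinearAlgebra

# Hypothetical helper: bytes allocated by one in-place multiply.
function mul_allocs!(out, X, W)
    mul!(out, X, transpose(W))   # warm-up call before measuring
    return @allocated mul!(out, X, transpose(W))
end

X = randn(10, 4)
W = randn(3, 4)
out = zeros(10, 3)

bytes = mul_allocs!(out, X, W)
println("bytes allocated by one mul! call: ", bytes)
```

Swapping in the three variants from the benchmark would show which of them allocate per call.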

cortner commented 1 year ago

Thanks for the careful testing.

cortner commented 1 year ago

Closed by #63?