DynareJulia / FastLapackInterface.jl

MIT License
31 stars 8 forks source link

Avoid unnecessary allocations in `resize!` by using views. #40

Open manuelbb-upb opened 4 months ago

manuelbb-upb commented 4 months ago

When looking into https://github.com/DynareJulia/FastLapackInterface.jl/issues/39 I noticed that for a `QRWYWs` the `T` matrix is re-allocated quite often, even when computing the QR factorization of a matrix smaller than the one specified at instantiation.

This can be avoided by using views. This is a proof-of-concept re-implementation based on what is currently in FastLapackInterface.

import LinearAlgebra: BlasInt, BlasFloat, require_one_based_indexing, chkstride1 
import LinearAlgebra.BLAS: @blasfunc
import LinearAlgebra.LAPACK: chklapackerror, liblapack

"""
    QRWYWs{E,M}

Reusable workspace for the blocked compact-WY QR factorization computed by `geqrt!`.

# Fields
- `work::Vector{E}`: scratch buffer handed to LAPACK as its `WORK` argument.
- `T::M`: storage for the triangular block-reflector factors; `geqrt!` writes into
  the leading corner of this matrix and returns a view of it.
"""
mutable struct QRWYWs{E<:Number, M<:StridedMatrix{E}}
    work::Vector{E}
    T::M
end

"""
    QRWYWs(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}

Create a workspace sized for factorizing `A`.

Starts from an empty workspace (empty `work` vector, `0 × 0` `T` matrix) and
forwards `kwargs` (e.g. `blocksize`, `work`) to `resize!`, whose result is returned.
"""
function QRWYWs(A::StridedMatrix{T}; kwargs...) where {T <: BlasFloat}
    empty_work = Vector{T}()
    empty_T = Matrix{T}(undef, 0, 0)
    empty_ws = QRWYWs(empty_work, empty_T)
    return resize!(empty_ws, A; kwargs...)
end

"""
    resize!(ws::QRWYWs, A::StridedMatrix; blocksize=36, work=true)

Size the workspace `ws` for a blocked QR factorization (`dgeqrt`) of `A`.

With `m, n = size(A)`, `m1 = min(m, n)` and block size `nb = min(m1, blocksize)`:
- `ws.T` is made an `nb × m1` matrix. A new matrix is allocated only when the current
  shape differs, so repeated calls for same-shaped inputs do not allocate.
- if `work` is `true`, `ws.work` is resized to length `nb * n`, the `WORK` length
  `dgeqrt` requires.

Throws (via `require_one_based_indexing` / `chkstride1`) if `A` has offset axes or is
not contiguous along its first dimension. Returns `ws`.
"""
function Base.resize!(ws::QRWYWs, A::StridedMatrix; blocksize=36, work=true)
    require_one_based_indexing(A)
    chkstride1(A)
    m, n = BlasInt.(size(A))
    m1 = min(m, n)
    nb = min(m1, blocksize)
    # Avoid a fresh allocation when T already has exactly the required shape.
    if size(ws.T) != (nb, m1)
        ws.T = similar(ws.T, nb, m1)
    end
    if work
        resize!(ws.work, nb * n)
    end
    return ws
end

"""
    geqrt!(ws::QRWYWs, A::AbstractMatrix{Float64}; resize=true, blocksize=36)

Compute the blocked QR factorization of `A` in place with LAPACK `dgeqrt`. The buffers
in `ws` are reused, so no new `T` matrix is allocated when `ws` is large enough.

With `m, n = size(A)` and `nb = min(min(m, n), blocksize)`, `dgeqrt` overwrites `A`
and writes the block-reflector factors into the leading `nb × min(m, n)` corner of
`ws.T`. That corner is returned as a view. It aliases `ws.T` and is overwritten by the
next call that uses the same workspace.

# Keywords
- `resize=true`: if `ws` is too small for `A`, grow it with `resize!`. "Too small"
  means `ws.T` has fewer than `nb` rows or `min(m, n)` columns, or `ws.work` is
  shorter than `nb * n`. If `false`, throw an error instead.
- `blocksize=36`: upper bound for the block size `nb`; must be at least 1.

# Returns
`(A, T)`, where `T` is a view into `ws.T`.

# Throws
- `ArgumentError` if `blocksize < 1`.
- An error if `A` has offset axes or is not contiguous along its first dimension.
- An error if the workspace is too small and `resize == false`.
- An error raised by `chklapackerror` if `dgeqrt` returns a nonzero `info`.
"""
function geqrt!(ws::QRWYWs, A::AbstractMatrix{Float64}; resize=true, blocksize=36)
    # A non-positive blocksize would give nb <= 0 and silently skip the factorization.
    blocksize >= 1 || throw(ArgumentError("blocksize must be positive, got $blocksize"))
    # A's memory is handed to LAPACK directly, so it must be a plain column-major block.
    require_one_based_indexing(A)
    chkstride1(A)
    m, n = size(A)
    minmn = min(m, n)
    nb = min(minmn, blocksize)
    # Check every dimension dgeqrt touches: T needs >= nb rows and >= minmn columns,
    # and WORK needs nb*n entries. Checking only the row count of T lets a wider A
    # index past ws.T (BoundsError) or let LAPACK write past the end of ws.work.
    if size(ws.T, 1) < nb || size(ws.T, 2) < minmn || length(ws.work) < nb * n
        if resize
            # Forward blocksize so the workspace is sized for the same nb used below.
            resize!(ws, A; work=true, blocksize=blocksize)
        else
            error("Cannot resize.")
        end
    end
    # Use only the leading corner of the (possibly larger) workspace: no allocation.
    # Taken after any resize!, since that may replace ws.T.
    T = @view ws.T[1:nb, 1:minmn]
    if nb > 0
        info = Ref{BlasInt}()
        ccall(
            (@blasfunc(dgeqrt_), liblapack), Cvoid,
            (Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
             Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64},
             Ptr{BlasInt}),
            m, n, nb, A,
            # LDT is the leading dimension of the full ws.T, not of the view.
            max(1, stride(A, 2)), T, max(1, stride(ws.T, 2)), ws.work,
            info,
        )
        chklapackerror(info[])
    end
    return A, T
end

At the moment, this supports only `Float64`, and I have not yet thought through the keyword arguments (e.g. `blocksize`). Still, it seems to save allocations:

import FastLapackInterface as FLA
# Compare the view-based `geqrt!` (ws1) with FastLapackInterface's (ws2). Both
# workspaces are sized for an n×n matrix; the loop then factors n×j matrices,
# including one (j = 60) wider than the workspace was sized for.
let n = 50
    A = rand(n, n)
    ws1 = QRWYWs(A)
    ws2 = FLA.QRWYWs(A)
    # Warm up both code paths so the first @time below does not include compilation.
    geqrt!(ws1, copy(A))
    geqrt!(ws2, copy(A))
    for j in (0, 5, 10, 25, 50, 60)
        B1 = rand(n, j)
        # geqrt! overwrites its input, so each implementation gets its own copy;
        # otherwise the second call would factor the first call's output.
        B2 = copy(B1)
        @show size(B1)
        println("new:")
        @time geqrt!(ws1, B1)
        println("old:")
        @time geqrt!(ws2, B2)
    end
end

gives

size(B) = (50, 0)
new:
  0.000001 seconds
old:
  0.000003 seconds (1 allocation: 64 bytes)
size(B) = (50, 5)
new:
  0.000049 seconds
old:
  0.000006 seconds (1 allocation: 256 bytes)
size(B) = (50, 10)
new:
  0.000016 seconds
old:
  0.000015 seconds (1 allocation: 896 bytes)
size(B) = (50, 25)
new:
  0.000048 seconds
old:
  0.000038 seconds (1 allocation: 5.062 KiB)
size(B) = (50, 50)
new:
  0.000098 seconds
old:
  0.000085 seconds (1 allocation: 14.188 KiB)
size(B) = (50, 60)
new:
  0.000112 seconds
old:
  0.000090 seconds (2 allocations: 46.125 KiB)

Is this worth a PR?

MichelJuillard commented 4 months ago

Yes @manuelbb-upb , please prepare a PR. Best