JuliaMath / RealDot.jl

Compute `real(dot(x, y))` efficiently.
MIT License
4 stars 3 forks source link

more efficient realdot for arrays #6

Open stevengj opened 2 years ago

stevengj commented 2 years ago

Right now you use real(LinearAlgebra.dot(x, y)), which does twice as much work as necessary for complex arrays.

You could do e.g.

"""
    realdot(x::Vector{Complex{T}}, y::Vector{Complex{T}}) where {T}

Return the dot product of `x` and `y` after reinterpreting both as vectors of `T`
(each complex entry becomes its real part followed by its imaginary part).
"""
function realdot(x::Vector{Complex{T}}, y::Vector{Complex{T}}) where {T}
    xr = reinterpret(T, x)
    yr = reinterpret(T, y)
    return dot(xr, yr)
end

and similar…

Or you could just write out the loop (two muladd calls per complex entry, with @simd), since BLAS doesn't have much of an advantage for this type of operation.

devmotion commented 2 years ago

I did not include it because of the discussion and benchmarks mentioned in https://github.com/JuliaLang/julia/issues/22261. However, I guess it would be good to benchmark it again, since that issue is quite old.

stevengj commented 2 years ago

I just tried a couple of options: the reinterpret trick was about 10% faster than real(dot(a,b)) on my machine for a 1000-element array, but simply writing out the loops was faster yet (about 25% faster):

"""
    realdot1(a::AbstractVector{Complex{T}}, b::AbstractVector{Complex{S}}) where {T,S}

Accumulate `real(a[i]) * real(b[i]) + imag(a[i]) * imag(b[i])` over all indices of `a`
and `b`, which is the real part of their dot product (up to rounding), using two
`muladd` calls per complex entry.

# Throws
- `DimensionMismatch`: if the axes of `a` and `b` are not equal.
"""
function realdot1(a::AbstractVector{Complex{T}}, b::AbstractVector{Complex{S}}) where {T,S}
    inds = axes(a, 1)
    if inds != axes(b, 1)
        msg = "first array has axes $(inds), which do not match the axes of the second, $(axes(b, 1))"
        throw(DimensionMismatch(msg))
    end
    # Seed the accumulator with a zero of the product type of `T` and `S`.
    s = zero(T) * zero(S)
    @simd for i in inds
        # `i` is drawn from axes that were checked to be equal for both arrays.
        @inbounds ai, bi = a[i], b[i]
        s = muladd(real(ai), real(bi), s)
        s = muladd(imag(ai), imag(bi), s)
    end
    return s
end
devmotion commented 2 years ago

I benchmarked the following three implementations on my computer with this script:

using BenchmarkTools
using Plots
using LinearAlgebra

# current fallback in RealDot
"""
    real_dot(x::AbstractVector{<:Complex}, y::AbstractVector{<:Complex})

Compute the full complex `dot(x, y)` and return only its real part.
"""
function real_dot(x::AbstractVector{<:Complex}, y::AbstractVector{<:Complex})
    complex_result = dot(x, y)
    return real(complex_result)
end

# dot product of real arrays
"""
    dot_reinterpret(x::Vector{Complex{S}}, y::Vector{Complex{T}}) where {S,T}

Reinterpret `x` as a vector of `S` and `y` as a vector of `T`, then return the dot
product of those two real-valued vectors.
"""
function dot_reinterpret(x::Vector{Complex{S}}, y::Vector{Complex{T}}) where {S,T}
    xr = reinterpret(S, x)
    yr = reinterpret(T, y)
    return dot(xr, yr)
end

# `@simd` loop
"""
    simd_muladd(a::AbstractVector{Complex{S}}, b::AbstractVector{Complex{T}}) where {S,T}

Accumulate `real(a[i]) * real(b[i]) + imag(a[i]) * imag(b[i])` over all indices of `a`
and `b` in a `@simd` loop, using two `muladd` calls per complex entry.

# Throws
- `DimensionMismatch`: if the axes of `a` and `b` are not equal.
"""
function simd_muladd(a::AbstractVector{Complex{S}}, b::AbstractVector{Complex{T}}) where {S,T}
    inds = axes(a, 1)
    # Compare by value (`==`), not identity (`===`): equal index ranges may be
    # represented by distinct objects or distinct range types.
    if inds != axes(b, 1)
        msg = "first array has axes $(inds), which do not match the axes of the second, $(axes(b, 1))"
        throw(DimensionMismatch(msg))
    end

    # Seed the accumulator with a zero of the product type of `S` and `T`.
    s = zero(S) * zero(T)
    @simd for i in inds
        # `i` is drawn from axes that were checked to be equal for both arrays.
        @inbounds re_ai, im_ai = reim(a[i])
        @inbounds re_bi, im_bi = reim(b[i])
        s = muladd(re_ai, re_bi, s)
        s = muladd(im_ai, im_bi, s)
    end

    return s
end

"""
    benchmark(n::Int)

Time `real_dot`, `dot_reinterpret` and `simd_muladd` with `@belapsed` on the same pair
of random `ComplexF64` vectors of length `n`.

Returns a named tuple `(; real_dot, dot_reinterpret, simd_muladd)` holding the value
`@belapsed` reports for each implementation.
"""
function benchmark(n::Int)
    a = randn(ComplexF64, n)
    b = randn(ComplexF64, n)

    # `\$` interpolation hands the arrays to the benchmark as values.
    elapsed_real_dot = @belapsed real_dot($a, $b)
    elapsed_reinterpret = @belapsed dot_reinterpret($a, $b)
    elapsed_simd = @belapsed simd_muladd($a, $b)

    return (
        real_dot=elapsed_real_dot,
        dot_reinterpret=elapsed_reinterpret,
        simd_muladd=elapsed_simd,
    )
end

"""
    plot_benchmarks(ns; kwargs...)

Call `benchmark(n)` for every `n` in `ns` and plot the three measured timings against
`ns`, one series per implementation. Extra keyword arguments are forwarded to `plot`.

Returns whatever `plot` returns.
"""
function plot_benchmarks(ns; kwargs...)
    series_labels = ["real_dot" "dot_reinterpret" "simd_muladd"]

    # One row per implementation, one column per problem size.
    timings = Matrix{Float64}(undef, 3, length(ns))
    for (col, n) in enumerate(ns)
        result = benchmark(n)
        timings[:, col] .= (result.real_dot, result.dot_reinterpret, result.simd_muladd)
    end

    # Pass the adjoint so that each implementation forms one column.
    return plot(ns, timings'; xlabel="n", ylabel="time", label=series_labels, kwargs...)
end

# Benchmark 10 sizes, logarithmically spaced from 10 to 10^6, with log10 scaling on both axes.
plot_benchmarks(
    floor.(Int, 10 .^ range(1, 6; length=10));
    xscale=:log10,
    yscale=:log10,
    legend=:outertopright,
)

I observe the following behaviour (see the attached plot "realdot"):

It seems simd_muladd is faster for small arrays (< 500 entries) but slower for arrays with more elements. Visually, there is also no clear performance improvement from using reinterpret.

stevengj commented 2 years ago

For big arrays, it is memory-bound and not compute-bound, which is probably why they all become about the same.