MasonProtter opened 4 years ago
Your second and fourth examples are silently getting wrong answers, which is bad. And the third example shouldn't be throwing "op not defined" errors, so that's something else for me to look into.
Those probably should have been "reduction not found" errors.
LLVM evaluates the loop serially, even with `@simd`; there are no vector instructions in the IR:
```julia
julia> function llvm_test(a,b,h,v)
           @inbounds @simd for i in eachindex(v)
               h = a * v[i] + b * h
           end;h
       end
llvm_test (generic function with 1 method)

julia> @code_llvm debuginfo=:none llvm_test(1.0,1.0,3.0,Float64[])
define double @julia_llvm_test_2311(double, double, double, %jl_value_t* nonnull align 16 dereferenceable(40)) {
top:
  %4 = bitcast %jl_value_t* %3 to %jl_value_t**
  %5 = getelementptr inbounds %jl_value_t*, %jl_value_t** %4, i64 3
  %6 = bitcast %jl_value_t** %5 to i64*
  %7 = load i64, i64* %6, align 8
  %8 = icmp sgt i64 %7, 0
  %9 = select i1 %8, i64 %7, i64 0
  br i1 %8, label %L13.lr.ph, label %L31

L13.lr.ph:                                        ; preds = %top
  %10 = bitcast %jl_value_t* %3 to double**
  %11 = load double*, double** %10, align 8
  br label %L13

L13:                                              ; preds = %L13, %L13.lr.ph
  %value_phi16 = phi double [ %2, %L13.lr.ph ], [ %15, %L13 ]
  %value_phi5 = phi i64 [ 0, %L13.lr.ph ], [ %16, %L13 ]
  %12 = getelementptr inbounds double, double* %11, i64 %value_phi5
  %13 = load double, double* %12, align 8
  %14 = fmul double %value_phi16, %1
  %15 = fadd double %14, %13
  %16 = add nuw nsw i64 %value_phi5, 1
  %17 = icmp ult i64 %16, %9
  br i1 %17, label %L13, label %L31

L31:                                              ; preds = %L13, %top
  %value_phi2 = phi double [ %2, %top ], [ %15, %L13 ]
  ret double %value_phi2
}
```
It's not immediately obvious how to SIMD this reduction.
The easiest way to SIMD reductions is to just do several of them in parallel. But because we're only calculating one `h`, that obviously won't work.
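For contrast, here is a minimal sketch of the easy case, assuming there are several independent recurrences to run (one `h[k]` per row of a hypothetical input matrix `V`). `many_recurrences!` is an illustrative name, not a LoopVectorization function. The loop over `k` carries no dependence between iterations, so that is the loop that could be mapped onto SIMD lanes, while the loop over time steps `i` stays serial.

```julia
# Hypothetical illustration: K independent recurrences h[k] = a*V[k, i] + b*h[k],
# advanced together. V is K×n, so the inner loop over k reads contiguous memory.
function many_recurrences!(h::Vector{Float64}, a::Float64, b::Float64, V::Matrix{Float64})
    @assert size(V, 1) == length(h)            # one row of V per recurrence
    for i in axes(V, 2)                        # serial over time steps
        @inbounds @simd for k in eachindex(h)  # independent across recurrences
            h[k] = a * V[k, i] + b * h[k]
        end
    end
    return h
end

many_recurrences!(zeros(3), 1.0, 0.5, ones(3, 4))   # each entry goes 1, 1.5, 1.75, 1.875
```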
Normally, LoopVectorization does reductions by trying to classify a reduction as one of:
1. accumulators initialized with `zero`, with `sum` to combine them.
2. accumulators initialized with `one`, with `prod` to combine them.
3. accumulators initialized with `typemin`, with `max` to combine them.
4. accumulators initialized with `typemax`, with `min` to combine them.
Importantly, each of these types of reductions is associative (and commutative), so we don't need to accumulate in order. Given 16 accumulators (e.g., 4x unrolling with a SIMD width of 4), the first accumulator would accumulate the 1st, 17th, 33rd, ... iterations of the loop.
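A minimal sketch of that accumulator layout, assuming a plain `+` reduction and emulating the accumulators with an ordinary array. `multi_accumulator_sum` is a hypothetical helper for illustration, not how LoopVectorization actually generates code.

```julia
# Hypothetical helper: sum `v` using `A` independent accumulators, mimicking how
# an associative reduction is spread across SIMD lanes and unrolled copies.
function multi_accumulator_sum(v::AbstractVector{T}, A::Int = 16) where {T}
    acc = zeros(T, A)            # one accumulator per lane, initialized with `zero`
    for i in 1:length(v)
        # iteration i goes to accumulator mod1(i, A); with A = 16,
        # accumulator 1 sees iterations 1, 17, 33, ...
        acc[mod1(i, A)] += v[i]
    end
    return sum(acc)              # combine the partial results with `+`
end

multi_accumulator_sum(collect(1.0:100.0))   # 5050.0
```

Because `+` is treated as associative and commutative, regrouping the iterations this way gives the same answer up to floating-point rounding.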
The problem is that the loop as written doesn't fall into any of those categories. Writing out the first few iterations (with `a = 1` for brevity):
h_1 = v[1] + b * h_0
h_2 = v[2] + b * h_1
h_3 = v[3] + b * h_2
h_4 = v[4] + b * h_3
However, with a bit of algebra...
h_4 = v[4] + b * (v[3] + b * (v[2] + b * (v[1] + b * h_0)))
h_4 = v[4] + (v[3] * b + (v[2] * b^2 + (v[1] * b^3 + b^4 * h_0)))
h_4 = (v[4] * b^-4 + (v[3] * b^-3 + (v[2] * b^-2 + (v[1] * b^-1 + h_0)))) * b^4
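Generalizing the four-term expansion to `n` iterations and restoring the factor `a` gives the closed form below (a sketch by induction on `n`, assuming `b ≠ 0` so that `b^-i` exists):

```math
h_n = a\,v_n + b\,h_{n-1}
\;\Longrightarrow\;
h_n = b^n h_0 + \sum_{i=1}^{n} a\,v_i\,b^{\,n-i}
    = b^n\left(h_0 + \sum_{i=1}^{n} a\,v_i\,b^{-i}\right)
```

With `a = 1` and `n = 4` this reduces to the last line of the expansion; the sum in parentheses is an ordinary additive reduction.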
...it turns out this was secretly an additive reduction all along, meaning the loop can be transformed into something like this:
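```julia
function test(a, b, h, v)
    binv = inv(b)                 # requires b != 0
    s = zero(h)
    for i in eachindex(v)         # assumes 1-based indexing, so i runs over 1:length(v)
        s += a * v[i] * binv ^ i  # additive reduction of a * v[i] * b^-i
    end
    (s + h) * b^length(v)         # h_n = b^n * (h_0 + s)
end
```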
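A quick way to sanity-check the rewrite is to compare it against the serial loop on a short vector. This is a minimal sketch, assuming 1-based `Vector` inputs and `b ≠ 0`; `serial_recurrence` and `transformed_recurrence` are illustrative names (the latter restates `test` so the block is self-contained).

```julia
# Serial reference: the original loop h = a*v[i] + b*h.
function serial_recurrence(a, b, h, v)
    for i in eachindex(v)
        h = a * v[i] + b * h
    end
    return h
end

# The rewritten loop: an additive reduction, then one scaling by b^n.
# Assumes b != 0 and 1-based indexing of v.
function transformed_recurrence(a, b, h, v)
    binv = inv(b)
    s = zero(h)
    for i in eachindex(v)
        s += a * v[i] * binv^i
    end
    return (s + h) * b^length(v)
end

v = [0.3, -1.2, 2.5, 0.7, 1.1]
serial_recurrence(2.0, 0.9, 3.0, v) ≈ transformed_recurrence(2.0, 0.9, 3.0, v)   # true
```

The two agree only up to rounding, since the transformed loop sums the terms in a different order and scales them by powers of `b`.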
This doesn't SIMD with LLVM (even with `@simd` and `@fastmath`), and it throws an error with LoopVectorization (`^(::whatever, ::_MM)` not defined), so that's a method that needs adding.
I'm currently rewriting VectorizationBase and a few other libraries (including parts of LoopVectorization), so it'll be a while until I can make these sorts of changes.
The thing we'd have to be sure of for performance is that `binv^i` gets calculated by initializing a vector `<b^-1, b^-2, ..., b^-A>`, where `A` is the number of accumulators (e.g., 8), and then updating it on each iteration by multiplying the vector by `b^-A`.
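A minimal sketch of that scheme in plain Julia, assuming the `A` lanes are emulated by an ordinary `Vector` (a real implementation would keep `pows` in a SIMD register). `lane_powers_sum` and `transformed_with_lanes` are hypothetical names. Lane `j` starts at `b^-j` and is multiplied by `b^-A` after each block of `A` iterations, so no `^` is evaluated inside the loop.

```julia
# Hypothetical sketch: s = Σ a*v[i]*b^-i using A emulated lanes.
# Assumes b != 0 and 1-based indexing of v.
function lane_powers_sum(a, b, v::AbstractVector{T}, A::Int = 8) where {T}
    binv = inv(b)
    R = typeof(a * zero(T) * binv)     # element type of the accumulators
    pows = [binv^j for j in 1:A]       # <b^-1, b^-2, ..., b^-A>
    binvA = binv^A                     # per-block update factor b^-A
    acc = zeros(R, A)                  # one accumulator per lane
    n = length(v)
    nblocks = div(n, A)
    for blk in 0:(nblocks - 1)
        base = blk * A
        for j in 1:A                   # the A lanes of one SIMD "vector"
            acc[j] += a * v[base + j] * pows[j]
        end
        pows .*= binvA                 # now pows[j] == b^-(base + A + j)
    end
    nrem = n - nblocks * A             # leftover iterations
    for j in 1:nrem
        acc[j] += a * v[nblocks * A + j] * pows[j]
    end
    return sum(acc)
end

# Final result of the transformed loop: h_n = (s + h) * b^n.
transformed_with_lanes(a, b, h, v; A = 8) = (lane_powers_sum(a, b, v, A) + h) * b^length(v)
```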
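I'd be worried about accuracy for long vectors. Multiplying by `b^length(v)` is a sign that things are likely to end badly unless `b` is very close to `1`: for `|b| < 1` the weights `binv^i` eventually overflow, and for `|b| > 1` they underflow while `b^length(v)` overflows.
That said, the serial version will probably blow up for deviant `b` as well (for `|b| > 1`, `h` grows roughly like `b^length(v)`).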
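To make the accuracy concern concrete, here is a small illustration with values chosen for the purpose (`a = 1`, `b = 0.5`, `h0 = 0`, `v = ones(1100)`): the serial loop converges, while the transformed reduction's weights `binv^i` overflow `Float64` once `i` reaches 1024.

```julia
# Illustration with a = 1, b = 0.5, h0 = 0 and v = ones(n).
b = 0.5
binv = inv(b)                  # 2.0
n = 1100

isfinite(binv^1023)            # true
isinf(binv^1024)               # true: Float64 overflows, so every later weight is Inf

# Transformed version: s = Σ binv^i is Inf, and scaling by b^n cannot recover it.
s = sum(binv^i for i in 1:n)
h_transformed = (s + 0.0) * b^n

# Serial version: h = v[i] + b*h stays bounded and approaches 2.0.
h_serial = foldl((h, x) -> x + b * h, ones(n); init = 0.0)
```

Here `h_transformed` is not finite (`Inf` times an underflowed `b^n`), while `h_serial` is approximately `2.0`.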
And if we handle `b^i` in that manner, the SIMD version should be really fast, much faster than serial, which would make it a really impressive benchmark for "look what this can manage to vectorize"!
Given your experience with SymbolicUtils.jl, do you have any suggestions for approaches to automating these sorts of transformations? That is, if a reduction doesn't fall into any of the 4 categories, how could we try to figure out whether the expression can be transformed so that it does fall into one of them?
Hm. Interesting question. It seems like a hard problem because what you're doing by hand is kinda like a partial loop-unrolling and then re-arranging. Certainly possible, but tricky to do automatically. An interesting application though for sure.
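One alternative that sidesteps `b^-1` entirely is to treat each iteration as an affine map and reduce with function composition. This is a swapped-in, well-known technique for parallelizing linear recurrences, not something proposed in this thread. Composition is associative, which is what lets the work be regrouped, but it is not commutative, so it does not fit the 4 categories above and would need an ordered combine step. A minimal sketch with hypothetical names `compose` and `affine_reduce`:

```julia
# Each iteration applies the affine map x -> m*x + c with m = b and c = a*v[i].
# compose(f, g) is the map "apply f, then g" for f = (m1, c1), g = (m2, c2):
#   m2*(m1*x + c1) + c2 = (m2*m1)*x + (m2*c1 + c2)
# Composition is associative (but not commutative), so chunks can be reduced
# independently as long as they are combined in their original order.
compose(f, g) = (g[1] * f[1], g[1] * f[2] + g[2])

# Hypothetical sketch; assumes 1-based indexing of v.
function affine_reduce(a, b, h, v; nchunks = 4)
    n = length(v)
    identity_map = (one(b), zero(h))       # x -> 1*x + 0
    total = identity_map
    for k in 1:nchunks
        lo = div((k - 1) * n, nchunks) + 1
        hi = div(k * n, nchunks)
        chunk = identity_map
        for i in lo:hi                     # chunks do not depend on each other
            chunk = compose(chunk, (b, a * v[i]))
        end
        total = compose(total, chunk)      # combine chunks left to right
    end
    return total[1] * h + total[2]         # apply the composed map to h0
end
```

The multiplier part of the composed map is just `b^n`, which may underflow to zero for `|b| < 1`, but it never needs `b^-i`, so the `b = 0.5`, `ones(1100)` case stays finite.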
Something is causing this third case to be unhappy:
Furthermore, if I try to fix it by changing around the syntax, it actually runs and gets an incorrect result: