JuliaLabs / Poly.jl

MIT License
16 stars 3 forks source link

Wrong answer in triangular solve #12

Open chriselrod opened 3 years ago

chriselrod commented 3 years ago
julia> using Poly, LinearAlgebra

julia> function solve_v2!(A, B, U)
         M, N = size(B)
         @inbounds @fastmath for m in 1:M
           for n in 1:N; A[m,n] = B[m,n]; end;
           for k in 1:N
             A[m,k] /= U[k,k]
             for n in k+1:N
               A[m,n] -= A[m,k] * U[k,n]
             end
           end
         end
         return A
       end
solve_v2! (generic function with 1 method)

julia> function solve_v2_poly!(A, B, U)
         M, N = size(B)
         @poly_loop for m in 1:M
           for n2 in 1:N; A[m,n2] = B[m,n2]; end;
           for k in 1:N
             A[m,k] /= U[k,k]
             for n in k+1:N
               A[m,n] -= A[m,k] * U[k,n]
             end
           end
         end
         return A
       end
solve_v2_poly! (generic function with 1 method)

julia> N = 128;

julia> B = rand(N,N); U = copy(UpperTriangular(rand(N,N)));

julia> A = similar(B);

julia> Ac = B / UpperTriangular(U);

julia> solve_v2!(fill!(A, NaN), B, U) ≈ Ac
true

julia> solve_v2_poly!(fill!(A, NaN), B, U) ≈ Ac
false
chriselrod commented 3 years ago

There is a problem in the tiling: the first 64 columns of the result are correct, but every column after them is wrong:

julia> solve_v2_poly!(fill!(A, NaN), B, U);

julia> A[:,1:64] ≈ Ac[:,1:64]
true

julia> A[:,65:end] .- Ac[:,65:end]
128×64 Matrix{Float64}:
  1.51032e13  -6.57644e12  -3.07688e12  -7.07514e12   2.77733e12  …  -7.86782e18  -7.9127e20    4.39274e20   9.48798e20  -1.69922e21
  4.06369e13  -1.76946e13  -8.27867e12  -1.90364e13   7.47272e12     -2.11692e19  -2.129e21     1.18191e21   2.55285e21  -4.57193e21
  4.38906e13  -1.91114e13  -8.94153e12  -2.05606e13   8.07105e12     -2.28642e19  -2.29946e21   1.27655e21   2.75725e21  -4.93799e21
  2.82469e13  -1.22996e13  -5.75454e12  -1.32323e13   5.19432e12     -1.47148e19  -1.47988e21   8.21553e20   1.77449e21  -3.17797e21
 -1.44681e13   6.29988e12   2.94749e12   6.77761e12  -2.66054e12      7.53696e18   7.57995e20  -4.20801e20  -9.08898e20   1.62776e21
  1.78873e13  -7.78869e12  -3.64405e12  -8.37932e12   3.28929e12  …  -9.31813e18  -9.37128e20   5.20246e20   1.12369e21  -2.01244e21
  7.84921e13  -3.4178e13   -1.59906e13  -3.67698e13   1.44339e13     -4.08894e19  -4.11226e21   2.28292e21   4.93094e21  -8.8309e21
  5.50112e13  -2.39537e13  -1.12071e13  -2.57701e13   1.0116e13      -2.86574e19  -2.88208e21   1.59999e21   3.45585e21  -6.18914e21
  2.18036e13  -9.494e12    -4.4419e12   -1.02139e13   4.00947e12     -1.13583e19  -1.14231e21   6.34152e20   1.36972e21  -2.45306e21
  3.58632e13  -1.5616e13   -7.30617e12  -1.68002e13   6.59489e12     -1.86825e19  -1.8789e21    1.04307e21   2.25296e21  -4.03486e21
 -1.08412e12   4.7206e11    2.2086e11    5.07857e11  -1.99359e11  …   5.64757e17   5.67978e19  -3.15313e19  -6.81052e19   1.21971e20
  3.21444e13  -1.39967e13  -6.54856e12  -1.50581e13   5.91104e12     -1.67452e19  -1.68407e21   9.34913e20   2.01934e21  -3.61647e21
  9.60195e13  -4.181e13    -1.95614e13  -4.49805e13   1.7657e13      -5.00201e19  -5.03054e21   2.7927e21    6.03203e21  -1.08028e22
  ⋮                                                               ⋱                ⋮
 -1.66741e12   7.26047e11   3.39691e11   7.81104e11  -3.06621e11      8.68617e17   8.73572e19  -4.84963e19  -1.04748e20   1.87596e20
  8.41113e13  -3.66248e13  -1.71354e13  -3.94021e13   1.54672e13     -4.38166e19  -4.40666e21   2.44636e21   5.28394e21  -9.46309e21
  1.73976e13  -7.57548e12  -3.54429e12  -8.14994e12   3.19924e12     -9.06304e18  -9.11474e20   5.06005e20   1.09293e21  -1.95735e21
  1.3966e13   -6.08124e12  -2.84519e12  -6.54239e12   2.5682e12      -7.27539e18  -7.31689e20   4.06197e20   8.77355e20  -1.57127e21
  6.66503e13  -2.90217e13  -1.35782e13  -3.12225e13   1.22563e13  …  -3.47206e19  -3.49186e21   1.93851e21   4.18703e21  -7.49862e21
 -4.16109e13   1.81187e13   8.4771e12    1.94927e13  -7.65183e12      2.16767e19   2.18003e21  -1.21024e21  -2.61403e21   4.68151e21
  5.90203e12  -2.56994e12  -1.20238e12  -2.76482e12   1.08533e12     -3.07459e18  -3.09212e20   1.71659e20   3.70771e20  -6.64019e20
  7.27321e13  -3.16699e13  -1.48172e13  -3.40715e13   1.33747e13     -3.78888e19  -3.8105e21    2.1154e21    4.5691e21   -8.18287e21
  6.18825e13  -2.69456e13  -1.26069e13  -2.8989e13    1.13796e13     -3.22369e19  -3.24207e21   1.79984e21   3.88751e21  -6.9622e21
  4.26236e13  -1.85597e13  -8.68342e12  -1.99671e13   7.83806e12  …  -2.22042e19  -2.23309e21   1.2397e21    2.67765e21  -4.79545e21
  8.07631e12  -3.51669e12  -1.64533e12  -3.78337e12   1.48515e12     -4.20725e18  -4.23125e20   2.34898e20   5.07361e20  -9.08641e20
  9.41189e13  -4.09824e13  -1.91742e13  -4.40902e13   1.73075e13     -4.903e19    -4.93097e21   2.73742e21   5.91263e21  -1.0589e22
chriselrod commented 3 years ago
julia> @macroexpand @poly_loop for m in 1:M
                  for n2 in 1:N; A[m,n2] = B[m,n2]; end;
                  for k in 1:N
                    A[m,k] /= U[k,k]
                    for n in k+1:N
                      A[m,n] -= A[m,k] * U[k,n]
                    end
                  end
                end
quote
    #= /home/chriselrod/.julia/dev/Poly/src/macros.jl:414 =#
    for c0 = 0:64:N
        for c1 = max(1, c0 - 63):64:N
            for c2 = 1:64:M
                if c0 == 0
                    for c4 = c1:min(N, c1 + 63)
                        for c5 = c2:min(M, c2 + 63)
                            $(Expr(:inbounds, true))
                            local var"#320#val" = (A[c5, c4] = B[c5, c4])
                            $(Expr(:inbounds, :pop))
                            var"#320#val"
                        end
                    end
                end
                for c3 = max(1, c0):min(N, c0 + 63)
                    if c3 >= c1
                        for c5 = c2:min(M, c2 + 63)
                            $(Expr(:inbounds, true))
                            local var"#321#val" = (A[c5, c3] /= U[c3, c3])
                            $(Expr(:inbounds, :pop))
                            var"#321#val"
                        end
                    end
                    for c4 = max(c1, c3 + 1):min(N, c1 + 63)
                        for c5 = c2:min(M, c2 + 63)
                            $(Expr(:inbounds, true))
                            local var"#322#val" = (A[c5, c4] -= A[c5, c3] * U[c3, c4])
                            $(Expr(:inbounds, :pop))
                            var"#322#val"
                        end
                    end
                end
            end
        end
    end
end

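The generated guard `c3 >= c1` has no upper bound. As a result, once `c0 >= 64`, a column `c3` is divided by `U[c3, c3]` in every column tile `c1` that starts at or below it. The column is therefore divided more than once, and the first division happens before all of its pending updates have been applied. The minimal change that makes this correct is to replace the condition: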

if c3 >= c1
    ....
end

with

if c1 <= c3 <= c1 + 63
    ....
end
using LinearAlgebra, Random, Test
"""
    solve_test!(A, B, U; tile = 64) -> A

Overwrite `A` with `B / U`, i.e. the solution `X` of `X * U = B`, reading only the
upper triangle of `U` (diagonal included). The loop nest mirrors the tiled schedule
that `@poly_loop` generated for `solve_v2_poly!`, with the diagonal-division guard
corrected so that each column is divided exactly once.

# Arguments
- `A`: `M × N` output; its initial contents are ignored (it is first filled from `B`).
- `B`: `M × N` right-hand side.
- `U`: `N × N` matrix whose upper triangle is the triangular factor.

# Keywords
- `tile::Integer = 64`: tile edge length for both rows and columns; `64` reproduces
  the generated code exactly.

# Throws
- `ArgumentError` if `tile < 1`.
- `DimensionMismatch` if `size(A) != size(B)` or `U` is not `N × N`.
"""
function solve_test!(A, B, U; tile::Integer = 64)
    tile >= 1 || throw(ArgumentError("tile must be positive, got $tile"))
    M, N = size(B)
    size(A) == (M, N) ||
        throw(DimensionMismatch("A has size $(size(A)) but B has size $(size(B))"))
    LinearAlgebra.checksquare(U) == N ||
        throw(DimensionMismatch("U has size $(size(U)), expected ($N, $N)"))

    span = tile - 1  # offset from the first to the last index of a tile
    for c0 in 0:tile:N                        # tile of pivot columns (c3)
        for c1 in max(1, c0 - span):tile:N    # tile of updated columns (c4)
            c1_end = min(N, c1 + span)
            for c2 in 1:tile:M                # tile of rows (c5)
                rows = c2:min(M, c2 + span)
                # The first pivot tile (c0 == 0) also copies B into this tile of A.
                if c0 == 0
                    for c4 in c1:c1_end, c5 in rows
                        A[c5, c4] = B[c5, c4]
                    end
                end
                for c3 in max(1, c0):min(N, c0 + span)
                    # Divide pivot column c3 only while visiting the column tile that
                    # owns it. The generated guard `c3 >= c1` lacked the upper bound,
                    # so the division was repeated once per column tile at or before
                    # c3, the first time before all updates to c3 had been applied.
                    if c1 <= c3 <= c1 + span
                        for c5 in rows
                            A[c5, c3] /= U[c3, c3]
                        end
                    end
                    # Eliminate pivot column c3 from the later columns of this tile.
                    for c4 in max(c1, c3 + 1):c1_end, c5 in rows
                        A[c5, c4] -= A[c5, c3] * U[c3, c4]
                    end
                end
            end
        end
    end
    return A
end

# Compare the tiled solver against LinearAlgebra's `B / U` for sizes that cross several
# 64-wide tile boundaries in both dimensions. A fixed seed makes runs reproducible.
# Adding `n * I` keeps `U` well conditioned, so a mismatch points to a scheduling bug
# rather than to rounding error amplified by an ill-conditioned triangular factor.
@testset "Solve Tests" begin
    rng = MersenneTwister(1234)
    for m in (1, 72, 129), n in 1:300
        B = rand(rng, m, n)
        U = UpperTriangular(rand(rng, n, n) + n * I)
        @test B / U ≈ solve_test!(similar(B), B, U)
    end
end

With the modified condition, all 300 tests pass:

julia> @testset "Solve Tests" begin; for n in 1:300
            B = rand(72,n); U = UpperTriangular(rand(n,n));
            @test B / U ≈ solve_test!(similar(B), B, U)
       end; end
Test Summary: | Pass  Total
Solve Tests   |  300    300
Test.DefaultTestSet("Solve Tests", Any[], 300, false, false)