dzhang314 / MultiFloats.jl

Fast, SIMD-accelerated extended-precision arithmetic for Julia
MIT License

Denormalization causes loss of accuracy in `*` #42

Open haampie opened 7 months ago

haampie commented 7 months ago

After running the QR algorithm for a number of iterations, I'm hitting numbers like these in 99% of cases, and they lose precision when multiplied:

julia> x = Float64x4((-1.4426767353575964e-39, -4.620737599002311e-57, 1.1532128077549776e-66, -9.883310571778616e-83))

julia> y = big(x)

julia> x * x - y * y
-1.329884630995506956660201769041660088331658728714499508848494315565488248944506e-132

julia> abs(x) * abs(x) - y * y
-5.683388719909582042824734830998315327342882525094367086266805748619453533154653e-144

Relative errors:

julia> (x * x - y * y) / (y * y)
-6.389632939012188407781360510046084336131978165576939912207721601547110752468559e-55

julia> (abs(x) * abs(x) - y * y) / (y * y)
-2.730670535139620858770944035756030902664721599555555738618279743956037242156906e-66

I think they're not normalized.
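A quick way to make "not normalized" concrete: in a strongly normalized expansion, each limb is at most half an ulp of the limb before it, so no two limbs overlap. The helper below checks that condition on a raw limb tuple such as `x._limbs`. It is a hypothetical sketch (`is_normalized` is not a MultiFloats.jl function), and MultiFloats.jl may rely on a weaker invariant internally.

```julia
# Hypothetical helper (not part of MultiFloats.jl): returns true when every
# limb is at most half an ulp of the previous limb, i.e. the limbs do not
# overlap (the usual "strongly normalized" condition).
function is_normalized(limbs::NTuple{N,Float64}) where {N}
    for i in 1:N-1
        hi, lo = limbs[i], limbs[i+1]
        if iszero(hi)
            # A zero limb may only be followed by zero limbs.
            iszero(lo) || return false
        elseif abs(lo) > eps(hi) / 2
            # eps(hi) is the spacing of Float64 values near hi (one ulp).
            return false
        end
    end
    return true
end

# The limbs of x from the first example:
is_normalized((-1.4426767353575964e-39, -4.620737599002311e-57,
               1.1532128077549776e-66, -9.883310571778616e-83))   # false
```

For that `x`, the third limb (about 1.2e-66) is far larger than half an ulp of the second limb (about 2.8e-73), so the check fails even though the first two limbs are fine.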


Example. Input matrix:

julia> H1
6×6 Matrix{MultiFloat{Float64, 4}}:
  1.98795  -0.354736   0.0797913   0.148136     -0.164454   0.369633
 -0.19463   0.598573   0.429203   -0.486762      0.239162   0.00142155
  0.0       0.0290679  0.556247   -0.562008     -0.385439  -0.214523
  0.0       0.0        0.511856    1.12336      -0.602872  -0.206262
  0.0       0.0        0.0        -1.49082e-25   1.38356   -0.937583
  0.0       0.0        0.0         0.0           0.022335   1.20556

After applying a double shift of the QR algorithm and moving the "bulge" to the last two rows, it ends up looking like this:

julia> H2
6×6 Matrix{MultiFloat{Float64, 4}}:
 1.9865    0.358077    0.115383   0.13357      -0.182391    0.361471
 0.200765  0.554469    0.581365  -0.204676     -0.249824   -0.0259713
 0.0       0.0246946   1.0663     0.734426     -0.377182   -0.11517
 0.0       0.0        -0.359419   0.658852     -0.582584   -0.315445
 0.0       0.0         0.0       -1.83218e-50   1.43778    -0.923614
 0.0       0.0         0.0       -3.00836e-51   0.0363045   1.15134

Here accuracy is still fine:

julia> H * Q - Q * H2 |> norm
5.414815616357242382449125468599721608592507512793183811901540346944e-64

julia> Q'Q-I |> norm
4.416391581658617777192079003247210327007965703744734256761335112015e-64

But the last Givens rotation, which zeros out the -3.00836e-51 entry, is much less accurate (the residual norms jump from roughly 5e-64 to roughly 1e-47):

julia> H3
6×6 Matrix{MultiFloat{Float64, 4}}:
 1.9865    0.358077    0.115383   0.13357      -0.121414    0.386247
 0.200765  0.554469    0.581365  -0.204676     -0.250731    0.0148499
 0.0       0.0246946   1.0663     0.734426     -0.390859   -0.052535
 0.0       0.0        -0.359419   0.658852     -0.625997   -0.216883
 0.0       0.0         0.0       -1.85672e-50   1.28839    -0.946116
 0.0       0.0         0.0        0.0           0.0138023   1.30073

julia> Q'Q-I |> norm
1.1366313071582258344782136697615383422393886966249503201836254869903e-47

julia> H * Q - Q * H3 |> norm
1.656359278104309870083428882117651148392957207448309578999648077939e-47

That's probably because the numbers from which the rotation is computed are not normalized; adding 0.0 changes the entries at (5,4) and (6,4):

julia> H2 .+ 0.0 .=== H2
6×6 BitMatrix:
 1  1  1  1  1  1
 1  1  1  1  1  1
 1  1  1  1  1  1
 1  1  1  1  1  1
 1  1  1  0  1  1
 1  1  1  0  1  1

If I normalize them "by hand" by adding 0.0 (here a = H2[5,4] and b = H2[6,4]), accuracy seems to be restored:

julia> c, s, = LinearAlgebra.givensAlgorithm(a, b)
(0.9867864895853341289209073557253047327730824456644072969143861256782, 0.16202599782705631811014484204531232306332276631781213554170381102448, -1.85671644730184222985716651733286367867540289938089541635541279028e-50)

julia> c*c+s*s-1
8.037197050005110683132676344698446e-48

julia> c, s, = LinearAlgebra.givensAlgorithm(a + 0.0, b + 0.0)
(0.9867864895853341289209073557253047327730824456604417981828460524574, 0.16202599782705631811014484204531232306332276631716101810582393526185, -1.856716447301842229857166517332863678675402899388356814331887958311e-50)

julia> c*c+s*s-1
-1.1543976842476927e-64

But... I can't write my algorithm like that.

How can I avoid loss of accuracy?
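The `c*c + s*s - 1` check used above measures how far a computed `(c, s)` pair is from an exact rotation. As a self-contained illustration that does not need MultiFloats.jl, here is a minimal textbook Givens rotation written generically over the float type, so the same check can be run in Float64 and BigFloat. `simple_givens` and `rotation_residual` are hypothetical names; this is not how `LinearAlgebra.givensAlgorithm` is implemented, and its sign convention differs (note the negative `r` in the output above).

```julia
# Hypothetical textbook Givens rotation (not LinearAlgebra.givensAlgorithm):
# returns (c, s, r) such that, in exact arithmetic,
# c*a + s*b == r and c*b - s*a == 0. Here r = hypot(a, b) is never negative.
function simple_givens(a::T, b::T) where {T<:AbstractFloat}
    r = hypot(a, b)
    iszero(r) && return (one(T), zero(T), zero(T))
    return (a / r, b / r, r)
end

# Deviation of (c, s) from an exact rotation: the c*c + s*s - 1 check.
rotation_residual(c, s) = c * c + s * s - 1

c, s, r = simple_givens(3.0, 4.0)
rotation_residual(c, s)      # at most a few eps() in Float64

cb, sb, rb = simple_givens(big(3.0), big(4.0))
rotation_residual(cb, sb)    # many orders of magnitude smaller in BigFloat
```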

dzhang314 commented 6 months ago

Hey @haampie, thank you for bringing this to my attention! This looks like an unfortunate case where my lax approach to renormalization breaks things... At the moment I'm not sure if there's a simple tweak that can fix this without having performance implications in other cases, but I'll certainly give it a try.

In the meantime, can you provide the full a._limbs and b._limbs tuples for the problematic numbers in c, s, = LinearAlgebra.givensAlgorithm(a, b)? I'd like to see how badly normalized they are, and it might be instructive to see how they interact in the add/mul algorithms. This isn't visible in printouts like 0.9867864895853341289209073557253047327730824456644072969143861256782 because I internally normalize before print calls.

haampie commented 6 months ago

I've finally found a rotation that produces a non-normalized MultiFloat:

julia> c = Float64x4((1.0, -1.2814460661601042e-50, -1.1456058604534196e-66, -4.018828477571062e-83))

julia> s = Float64x4((1.6009035362320234e-25, -8.470934410604026e-42, 4.900032439512285e-58, -3.05869844561302e-74))

julia> c*c + s*s - 1  # check that (c, s) preserves the norm
-1.15647600223948327554941871035198256900584233789933e-64

julia> x = Float64x4((-2.3017404993032726e-25, -1.8187505516645134e-41, 4.637866565216834e-58, -2.3866542964252726e-74))

julia> y = Float64x4((1.4377758854357827, -9.834411896819007e-17, 4.676677591256931e-33, 7.260162680046171e-50))

julia> z = c * x + s * y
-1.8321827051883379215672299273506959629434104035485e-50

julia> z._limbs
(-1.832182705188338e-50, 6.742120038269503e-67, -5.263781917376804e-74, 0.0)

julia> z + 0.0 === z
false

dzhang314 commented 1 week ago

Hey @haampie! I had a chance to dig into this today, and unfortunately, I don't think it's possible to fix this issue without significant performance impact to MultiFloats.jl.

First, just to make sure we're on the same page: I've verified that no accuracy is lost in the calculation z = c * x + s * y. If you run the following Julia script:

using MultiFloats

# cover full Float64 exponent range, including subnormals
setprecision(exponent(floatmax(Float64)) - exponent(floatmin(Float64)) + precision(Float64))
c = Float64x4((1.0, -1.2814460661601042e-50, -1.1456058604534196e-66, -4.018828477571062e-83))
s = Float64x4((1.6009035362320234e-25, -8.470934410604026e-42, 4.900032439512285e-58, -3.05869844561302e-74))
x = Float64x4((-2.3017404993032726e-25, -1.8187505516645134e-41, 4.637866565216834e-58, -2.3866542964252726e-74))
y = Float64x4((1.4377758854357827, -9.834411896819007e-17, 4.676677591256931e-33, 7.260162680046171e-50))

cx = c * x
cx_big = big(c) * big(x)
println("Accurate bits in cx: ", round(-log2(abs(cx_big - cx) / abs(cx_big))))

sy = s * y
sy_big = big(s) * big(y)
println("Accurate bits in sy: ", round(-log2(abs(sy_big - sy) / abs(sy_big))))

cxpsy = cx + sy
cxpsy_big = big(cx) + big(sy)
println("Accurate bits in cx+sy: ", round(-log2(abs(cxpsy - cxpsy_big) / abs(cxpsy_big))))

You should get this output:

Accurate bits in cx: 216.0
Accurate bits in sy: 216.0
Accurate bits in cx+sy: Inf
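The "accurate bits" figure printed by the script is -log2 of the relative error against a BigFloat reference. Factored out as a small helper (`accurate_bits` is a hypothetical name, not part of MultiFloats.jl), it also makes the `Inf` in the last line easy to read: it appears exactly when the error is zero.

```julia
# Hypothetical helper reproducing the script's metric: the number of bits of
# `approx` that agree with the high-precision reference `exact`.
# A zero error gives -log2(0) = Inf, which is what the cx+sy line reports.
function accurate_bits(approx, exact::BigFloat)
    return -log2(abs(exact - approx) / abs(exact))
end

accurate_bits(1.0 + 2.0^-20, big(1.0))   # 20.0: relative error is exactly 2^-20
accurate_bits(0.5, big(0.5))             # Inf: no error at all
```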

So, the pathology here is not in the value of cxpsy, but in the fact that it is very badly denormalized. This happens because cx and sy are nearly equal in magnitude and opposite in sign. If you compare cx._limbs to sy._limbs:

(-2.3017404993032726e-25, -1.8187505516645134e-41, 4.637866565216834e-58, -2.091698665609915e-74)
(2.3017404993032726e-25, 1.8187505498323308e-41, -1.1620120442135492e-57, -3.1720832517668885e-74)
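The cancellation in the leading limbs can be reproduced with plain Float64 arithmetic. Below is Knuth's branch-free TwoSum written as a standalone function and applied to the first two limbs of cx and sy listed above. `knuth_two_sum` is a hypothetical name; MultiFloats' own `two_sum` may be implemented differently.

```julia
# Knuth's TwoSum (standalone illustration, not MultiFloats.two_sum):
# s = fl(a + b) and e is the exact rounding error, so s + e == a + b exactly.
function knuth_two_sum(a::Float64, b::Float64)
    s = a + b
    bb = s - a
    e = (a - (s - bb)) + (b - bb)
    return (s, e)
end

cx1, sy1 = -2.3017404993032726e-25, 2.3017404993032726e-25   # first limbs
cx2, sy2 = -1.8187505516645134e-41, 1.8187505498323308e-41   # second limbs

knuth_two_sum(cx1, sy1)   # (0.0, 0.0): the leading limbs cancel exactly
knuth_two_sum(cx2, sy2)   # tiny sum (about -1.83e-50) with zero error
```

The second pair is subtracted exactly because the operands are within a factor of two of each other (Sterbenz's lemma), so the raw sum starts with a zero limb followed by a limb around 1e-50, which matches the leading 0.0 in the debug_add output.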

You can see that the first limbs of cx and sy are equal in magnitude with opposite signs, and their second limbs nearly cancel. These values propagate through +(::Float64x4, ::Float64x4) in an unfortunate way, which we can see if we look at the limbs of the sum prior to normalization:

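using MultiFloats: two_sum, _accurate_sum

# Mirrors the steps of Float64x4 addition up to, but not including,
# renormalization, so the raw limbs of the sum can be inspected.
function debug_add(a, b)
    (a1, a2, a3, a4) = a._limbs
    (b1, b2, b3, b4) = b._limbs
    # Error-free sums of corresponding limbs: s_i + e_i == a_i + b_i exactly.
    (s1, e1) = two_sum(a1, b1)
    (s2, e2) = two_sum(a2, b2)
    (s3, e3) = two_sum(a3, b3)
    (s4, e4) = two_sum(a4, b4)
    # Combine each error term with the sums at the next-lower limb position.
    (x1,) = _accurate_sum(Val{1}(), s1)
    (x2, t1) = _accurate_sum(Val{2}(), e1, s2)
    (x3, t2, t3) = _accurate_sum(Val{3}(), e2, s3, t1)
    (x4, t4) = _accurate_sum(Val{2}(), e3, s4, t2)
    (x5,) = _accurate_sum(Val{1}(), e4, t3, t4)
    # Five limbs that would normally be passed to renormalization.
    return (x1, x2, x3, x4, x5)
end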

After defining this function, you should see that debug_add(cx, sy) returns:

(0.0, -1.8321826353657992e-50, -6.982253876918658e-58, -5.263781917376804e-74, 0.0)

Here, the first and second components are unexpectedly small compared to the input operands, and this denormalization is not corrected by the default MultiFloats._two_pass_renorm strategy. It needs to be applied twice:

julia> MultiFloats._two_pass_renorm(Val{4}(), 0.0, -1.8321826353657992e-50, -6.982253876918658e-58, -5.263781917376804e-74, 0.0)
(-1.832182705188338e-50, 6.742120038269503e-67, -5.263781917376804e-74, 0.0)

julia> MultiFloats._two_pass_renorm(Val{4}(), -1.832182705188338e-50, 6.742120038269503e-67, -5.263781917376804e-74, 0.0)
(-1.832182705188338e-50, 6.742119511891311e-67, 4.556505092611046e-83, 0.0)
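To see why a second pass can matter, here is a deliberately simple renormalization sweep built only from error-free TwoSum steps: a bottom-up pass that carries the running sum toward the leading limb, then a top-down pass that pushes rounding errors toward the trailing limbs. This is a sketch under stated assumptions, not `MultiFloats._two_pass_renorm`; `knuth_two_sum`, `renorm_sweep`, and `renormalize_fixpoint` are hypothetical names. Instead of assuming one sweep is always enough, it repeats the sweep until the limbs stop changing, with a small cap.

```julia
# Knuth's TwoSum: s == fl(a + b) and s + e == a + b exactly.
function knuth_two_sum(a::Float64, b::Float64)
    s = a + b
    bb = s - a
    e = (a - (s - bb)) + (b - bb)
    return (s, e)
end

# One sweep over the limbs. Every step is error-free, so the exact sum of the
# limbs is unchanged.
function renorm_sweep(x::NTuple{N,Float64}) where {N}
    v = collect(x)
    for i in N-1:-1:1    # bottom-up: gather the magnitude into v[1]
        v[i], v[i+1] = knuth_two_sum(v[i], v[i+1])
    end
    for i in 1:N-1       # top-down: push rounding errors toward v[N]
        v[i], v[i+1] = knuth_two_sum(v[i], v[i+1])
    end
    return ntuple(i -> v[i], Val(N))
end

# Repeat the sweep until it reaches a fixed point (at most `maxsweeps` times).
function renormalize_fixpoint(x::NTuple{N,Float64}, maxsweeps::Int = N) where {N}
    cur = x
    for _ in 1:maxsweeps
        nxt = renorm_sweep(cur)
        nxt == cur && return nxt
        cur = nxt
    end
    return cur
end

raw = (0.0, -1.8321826353657992e-50, -6.982253876918658e-58,
       -5.263781917376804e-74, 0.0)      # debug_add(cx, sy) from above
one_sweep = renorm_sweep(raw)
y = renormalize_fixpoint(raw)
```

On these limbs, one sweep of the sketch should move the sum into the leading limb but leave the -5.26e-74 limb overlapping the second one, much like the first `_two_pass_renorm` call above; the next sweep removes the overlap.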

I've designed MultiFloats._two_pass_renorm to be fast in the happy path where a + b is roughly the same order of magnitude as a and b. Conventionally, we expect floating-point pathologies when a and b are roughly equal with opposite signs, causing destructive cancellation in a + b, so I intentionally deprioritized this case.

I think the easiest way to work around this is to manually call MultiFloats.renormalize whenever you have a sum that you expect to be nearly zero. This is not usually a problem because denormalization is self-correcting in sums -- if you accumulate a large number of terms, it washes itself out. As you observed, this only causes issues when you use a denormalized number in a product.

In the next release of MultiFloats.jl, I've added an overload for MultiFloats.renormalize that is simply a no-op for Real numbers other than MultiFloat types, making it safe to call MultiFloats.renormalize in generic code.
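In generic code, the suggested workaround amounts to routing cancellation-prone results through a renormalization hook. Below is a minimal sketch of that pattern using a hypothetical `renorm` function that is a no-op for ordinary reals; it stands in for `MultiFloats.renormalize`, whose MultiFloat method is not reproduced here, so the sketch runs without the package.

```julia
# Hypothetical hook standing in for MultiFloats.renormalize: a no-op for
# plain real numbers. A MultiFloat-specific method would renormalize limbs.
renorm(x::Real) = x

# Apply the rotation (c, s) to the pair (x, y). Either output can suffer
# destructive cancellation (as z = c * x + s * y did in this thread), so both
# are passed through `renorm` before they are used in later products.
function apply_rotation(c::T, s::T, x::T, y::T) where {T<:Real}
    u = renorm(c * x + s * y)
    v = renorm(c * y - s * x)
    return (u, v)
end

u, v = apply_rotation(0.6, 0.8, 3.0, 4.0)   # u ≈ 5.0, v ≈ 0.0
```

With MultiFloats.jl loaded, one would presumably call `MultiFloats.renormalize` directly in place of `renorm`; the extra call only matters for intermediates that are expected to be close to zero.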

If this is unacceptable for your application, I can also provide a PreciseFloat64xN type with all arithmetic operations defined to be strongly normalizing. This comes at a significant performance cost. On my machine, a standard Float64x4 add takes 6.5ns, while a strongly normalizing add takes 15ns. However, if your application really needs this, it is something I am happy to support.

edljk commented 6 days ago

A safer type PreciseFloat64xN would be great! Thanks a lot.

dzhang314 commented 2 days ago

@edljk Sure thing! While I'm working on PreciseFloat64xN, let me make one more remark: I expect strong normalization after every arithmetic operation to be quite expensive. In most applications, I expect it to be sufficient and faster to just add one limb.

For example, if you're having problems with Float64x4, try going to Float64x5 instead of PreciseFloat64x4. Renormalizing with an extra limb may help with destructive cancellation, and conversion between Float64xM and Float64xN is highly optimized (it just cuts off limbs or appends zeros, purely in registers, with no dynamic memory allocation).
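The limb-count conversion described above can be pictured on a plain tuple of limbs: widening appends zero limbs, and narrowing drops trailing ones. The function below is a hypothetical sketch of that idea on `NTuple` values, not MultiFloats.jl's conversion code; narrowing this way is only a sensible rounding when the input is already normalized.

```julia
# Hypothetical illustration (not MultiFloats.jl internals): change the number
# of Float64 limbs from M to N. Missing limbs are zero-filled; surplus
# trailing limbs are dropped. Passing Val(N) makes the result length known
# at compile time.
function resize_limbs(x::NTuple{M,Float64}, ::Val{N}) where {M,N}
    return ntuple(i -> i <= M ? x[i] : 0.0, Val(N))
end

x4 = (1.0, 2.0^-60, 2.0^-120, 2.0^-180)
x5 = resize_limbs(x4, Val(5))   # (1.0, 2.0^-60, 2.0^-120, 2.0^-180, 0.0)
resize_limbs(x5, Val(4)) == x4  # true: the appended zero limb is dropped again
```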