odin-lang / Odin

Odin Programming Language
https://odin-lang.org
BSD 3-Clause "New" or "Revised" License
6.94k stars 611 forks source link

math.round is wrong on some values #2705

Closed ryhor-spivak closed 8 months ago

ryhor-spivak commented 1 year ago

Context

Expected Behavior

I expect this program:

package main

import "core:fmt"
import "core:math"

// Prints `x` followed by `math.round(x)`; each value is shown both in
// decimal and in `%h` (hexadecimal bit pattern) form.
round_test :: proc(x: $FP) {
    rounded := math.round(x)
    fmt.printf("%.17f(%h) rounds to %f(%h)\n", x, x, rounded, rounded)
}

// Feeds `round_test` four probe values per float width: the literal just
// below one half, an odd integer at or above 2^(mantissa bits), and the
// negations of both. Order of output is f16, then f32, then f64.
main :: proc() {
    f16_cases := [?]f16{0.4998, 1025, -0.4998, -1025}
    for v in f16_cases {
        round_test(v)
    }

    f32_cases := [?]f32{0.49999997, 8388611, -0.49999997, -8388611}
    for v in f32_cases {
        round_test(v)
    }

    f64_cases := [?]f64{0.49999999999999994, 4503599627370497, -0.49999999999999994, -4503599627370497}
    for v in f64_cases {
        round_test(v)
    }
}

to produce this output:

0.49975585937500000(0h37ff) rounds to 0.000(0h0)
1025.00000000000000000(0h6401) rounds to 1025.000(0h6401)
-0.49975585937500000(0hb7ff) rounds to -0.000(0h8000)
-1025.00000000000000000(0he401) rounds to -1025.000(0he401)
0.49999997019767761(0h3effffff) rounds to 0.000(0h0)
8388611.00000000000000000(0h4b000003) rounds to 8388611.000(0h4b000003)
-0.49999997019767761(0hbeffffff) rounds to -0.000(0h80000000)
-8388611.00000000000000000(0hcb000003) rounds to -8388611.000(0hcb000003)
0.49999999999999994(0h3fdfffffffffffff) rounds to 0.000(0h0)
4503599627370497.00000000000000000(0h4330000000000001) rounds to 4503599627370497.000(0h4330000000000001)
-0.49999999999999994(0hbfdfffffffffffff) rounds to -0.000(0h8000000000000000)
-4503599627370497.00000000000000000(0hc330000000000001) rounds to -4503599627370497.000(0hc330000000000001)

Current Behavior

But currently, it produces this:

0.49975585937500000(0h37ff) rounds to 1.000(0h3c00)
1025.00000000000000000(0h6401) rounds to 1026.000(0h6402)
-0.49975585937500000(0hb7ff) rounds to -1.000(0hbc00)
-1025.00000000000000000(0he401) rounds to -1026.000(0he402)
0.49999997019767761(0h3effffff) rounds to 1.000(0h3f800000)
8388611.00000000000000000(0h4b000003) rounds to 8388612.000(0h4b000004)
-0.49999997019767761(0hbeffffff) rounds to -1.000(0hbf800000)
-8388611.00000000000000000(0hcb000003) rounds to -8388612.000(0hcb000004)
0.49999999999999994(0h3fdfffffffffffff) rounds to 1.000(0h3ff0000000000000)
4503599627370497.00000000000000000(0h4330000000000001) rounds to 4503599627370498.000(0h4330000000000002)
-0.49999999999999994(0hbfdfffffffffffff) rounds to -1.000(0hbff0000000000000)
-4503599627370497.00000000000000000(0hc330000000000001) rounds to -4503599627370498.000(0hc330000000000002)

My attempt to fix this (https://github.com/odin-lang/Odin/pull/2675) was closed as incorrect, so this should be fixed in some other way.

ryhor-spivak commented 1 year ago

I still don't know what is wrong with my fix. You can just loop through all bit patterns of f16 and f32 and check it works as expected:

package main

import "core:fmt"
import "core:math"

// Rounds `x` to the nearest integer, with halfway cases going away from zero.
//
// The magnitude is rounded and the sign of `x` is put back afterwards, so a
// negative zero input yields negative zero (branching on `x < 0` sent -0.0
// down the positive path and returned +0.0).
//
// 0.4998 converts to 0h37ff, the f16 immediately below 0.5. Adding it instead
// of 0.5 keeps inputs just under one half, and integers too large to carry a
// fractional part, from being pushed up to the next integer. This relies on
// the addition rounding to nearest with ties to even.
@(require_results)
round_f16_fix :: proc "contextless" (x: f16)   -> f16 {
    return math.copy_sign(math.floor(abs(x) + 0.4998), x)
}

// Rounds `x` to the nearest integer, with halfway cases going away from zero.
//
// The magnitude is rounded and the sign of `x` is put back afterwards, so a
// negative zero input yields negative zero (branching on `x < 0` sent -0.0
// down the positive path and returned +0.0).
//
// 0.49999997 converts to 0h3effffff, the f32 immediately below 0.5. Adding it
// instead of 0.5 keeps inputs just under one half, and integers too large to
// carry a fractional part, from being pushed up to the next integer. This
// relies on the addition rounding to nearest with ties to even.
@(require_results)
round_f32_fix :: proc "contextless" (x: f32)   -> f32 {
    return math.copy_sign(math.floor(abs(x) + 0.49999997), x)
}

// Rounds `x` to the nearest integer, with halfway cases going away from zero.
//
// The magnitude is rounded and the sign of `x` is put back afterwards, so a
// negative zero input yields negative zero (branching on `x < 0` sent -0.0
// down the positive path and returned +0.0).
//
// 0.49999999999999994 converts to 0h3fdfffffffffffff, the f64 immediately
// below 0.5. Adding it instead of 0.5 keeps inputs just under one half, and
// integers too large to carry a fractional part, from being pushed up to the
// next integer. This relies on the addition rounding to nearest with ties to
// even.
@(require_results)
round_f64_fix :: proc "contextless" (x: f64)   -> f64 {
    return math.copy_sign(math.floor(abs(x) + 0.49999999999999994), x)
}

// Overload set: `round_fix(x)` dispatches on the float width of `x`.
round_fix :: proc{round_f16_fix, round_f32_fix, round_f64_fix}

// Checks `round_fix(x)` against a reference answer derived from
// `math.floor(x)` and `math.ceil(x)`, and prints one "fail on ..." line when
// they disagree. Prints nothing when the result is correct.
//
// Reference rule: pick whichever of floor/ceil is closer to `x`; on an exact
// tie pick the one further from zero (ceil for positive `x`, floor otherwise).
round_check :: proc(x: $FP)
{
    rx := round_fix(x)
    //rx := math.round(x) // this will show all bad cases of current math.round version

    // NaN must be handled completely here and never reach the comparisons
    // below: floor/ceil of NaN are NaN, and NaN is unordered with respect to
    // every value, so a correct NaN result cannot be validated with `!=`.
    // The only requirement for a NaN input is a NaN output.
    if math.is_nan(x)
    {
        if !math.is_nan(rx) do fmt.printf("fail on %.9f(%h): rounds to %f(%h)\n", x, x, rx, rx )
        return
    }

    f := math.floor(x)
    c := math.ceil(x)

    // For infinities both differences are NaN, so the comparison is false and
    // `c` (which equals `x`) is selected.
    expected: FP
    if x > 0
    {
        expected = f if x - f < c - x else c
    }
    else
    {
        expected = f if x - f <= c - x else c
    }

    if rx != expected do fmt.printf("fail on %.9f(%h): rounds to %f(%h)\n", x, x, rx, rx )
}

// Runs `round_check` on every f16 bit pattern and then on every f32 bit
// pattern, in ascending order of the raw unsigned encoding.
main :: proc() {
    fmt.printf("checking all f16:\n")
    // Iterate in a wider integer type so the inclusive upper bound
    // (the largest u16) cannot overflow the loop counter.
    for bits in 0..=int(max(u16)) {
        round_check(transmute(f16)u16(bits))
    }

    fmt.printf("checking all f32:\n")
    // Same idea with u64 as the counter for the full u32 range.
    for bits in u64(0)..=u64(max(u32)) {
        round_check(transmute(f32)u32(bits))
    }
}
ListeriaM commented 1 year ago

round(-0.0) should return -0.0, also I don't know if you can rely on 0.5 + 0.499... == 1

ryhor-spivak commented 1 year ago

-0 can be handled by

// Rounds the magnitude of `x` by adding 0.4998 and flooring, then reapplies
// the sign of `x` to the result (so a negative zero input stays negative).
@(require_results)
round_f16_fix :: proc "contextless" (x: f16) -> f16 {
    magnitude := abs(x)
    rounded   := math.floor(magnitude + 0.4998)
    return math.copy_sign(rounded, x)
}

// Rounds the magnitude of `x` by adding 0.49999997 and flooring, then
// reapplies the sign of `x` to the result (so a negative zero input stays
// negative).
@(require_results)
round_f32_fix :: proc "contextless" (x: f32) -> f32 {
    magnitude := abs(x)
    rounded   := math.floor(magnitude + 0.49999997)
    return math.copy_sign(rounded, x)
}

// Rounds the magnitude of `x` by adding 0.49999999999999994 and flooring,
// then reapplies the sign of `x` to the result (so a negative zero input
// stays negative).
@(require_results)
round_f64_fix :: proc "contextless" (x: f64) -> f64 {
    magnitude := abs(x)
    rounded   := math.floor(magnitude + 0.49999999999999994)
    return math.copy_sign(rounded, x)
}

And this probably will be faster too.

As for relying on floating-point addition: everywhere else we already rely on it rounding to the nearest representable float, with ties to even.