gabrielpreviato opened this issue 1 year ago
Yes, this is expected given the current implementation. At https://github.com/JuliaStats/Statistics.jl/issues/22 we fixed overflow with integer types, but we haven't discussed the issue of precision of floating-point types.
In the second example you give, the main issue is that `Float16(5e4)/100_000` gives zero because `Float16(100_000)` is infinite. We could perform the final division in `Float64` without a significant cost, before converting the result back to `Float16`. That would return a value close to 0.5, though not as precise as doing the summation in `Float64`.
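For illustration, the failure can be reproduced directly in the REPL (`floatmax(Float16)` is only 65504, so the integer divisor overflows when converted):

```julia
julia> Float16(100_000)        # 100_000 > floatmax(Float16) == 65504
Inf16

julia> Float16(5e4) / 100_000  # the Int divisor is converted to Float16 first
Float16(0.0)
```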
@stevengj @StefanKarpinski Do you think we should do something about this?
Off the top of my head, I don't really see a good way to do an accurate `mean` computation of a big array entirely in `Float16`, especially if `length(x)` is not representable in `Float16`. If you're going to promote anything, why not just compute the whole sum in `Float64` or `Float32` arithmetic?
Wouldn't it be reasonable to just have a specialized `mean` (etc.) implementation for `AbstractArray{<:Union{Float16,ComplexF16}}` that does all the intermediate calculations in `Float64`?
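A minimal sketch of that idea, restricted to real `Float16` for brevity (`mean16` is a hypothetical name, not an actual Statistics.jl function):

```julia
using Statistics

# Hypothetical specialization: do all intermediate arithmetic in Float64,
# then convert the result back to Float16.
mean16(A::AbstractArray{Float16}) = Float16(sum(Float64, A) / length(A))

mean(rand(Float16, 10^6))    # overflows: NaN16
mean16(rand(Float16, 10^6))  # ≈ Float16(0.5)
```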
Given that `sum` does computations in `Float16`, I'm not sure there's a good rationale for `mean` doing something different, since they both return a `Float16`. Where `mean` differs is that it divides by the length of the array, and that operation gives really absurd results when done in `Float16`. That's why I thought doing that operation in a wider type could make sense (TBH I'm not even sure why we don't do that by default, as it's really dangerous).
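The division step can be seen in isolation: even when the sum itself fits in `Float16`, dividing by the length in `Float16` destroys the result (a sketch):

```julia
julia> s = Float16(4.8e4)      # a sum that still fits in Float16
Float16(4.8e4)

julia> s / 100_000             # the divisor overflows to Inf16 on conversion
Float16(0.0)

julia> Float16(Float64(s) / 100_000)   # divide in Float64, convert back
Float16(0.48)
```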
One solution I thought of is breaking bigger arrays into smaller chunks. Something like this:
```julia
using Statistics

new_mean(A::AbstractArray{<:Union{Float16,ComplexF16}}; dims=:) = _mean_small(identity, A, dims)

function _mean_small(f, A::AbstractArray{<:Union{Float16,ComplexF16}}, dims::Dims=:) where Dims
    isempty(A) && return sum(f, A, dims=dims)/0
    if dims === (:)
        n = length(A)
    else
        n = mapreduce(i -> size(A, i), *, unique(dims); init=1)
    end
    x1 = f(first(A)) / 1
    if n > 10000
        # Split the array into chunks of at most 10_000 elements and combine
        # the per-chunk means, weighted by chunk length (assumes dims === (:)).
        chunks = [collect(1:10000:length(A)); length(A) + 1]
        lens = chunks[2:end] .- chunks[1:end-1]
        A_views = [view(A, chunks[i]:chunks[i+1]-1) for i in 1:length(chunks)-1]
        # Statistics._mean_promote is internal; it widens the accumulator type.
        result = sum.(x -> Statistics._mean_promote(x1, f(x)), A_views) ./ lens
        return sum(result .* lens ./ 10000) / (length(chunks) - 1)
    else
        result = sum(x -> Statistics._mean_promote(x1, f(x)), A, dims=dims)
    end
    if dims === (:)
        return result / n
    else
        return result ./= n
    end
end
```
With this, the following results would be obtained:
```julia
julia> data = rand(Float16, 10^4)
10000-element Vector{Float16}:

julia> sum(data)
Float16(4.97e3)

julia> mean(data)
Float16(0.4968)

julia> new_mean(data)
Float16(0.4968)

julia> data = rand(Float16, 10^5)
100000-element Vector{Float16}:

julia> sum(data)
Float16(4.992e4)

julia> mean(data)
Float16(0.0)

julia> new_mean(data)
Float16(0.4993)

julia> data = rand(Float16, 10^6)
1000000-element Vector{Float16}:

julia> sum(data)
Inf16

julia> mean(data)
NaN16

julia> new_mean(data)
Float16(0.4998)
```
Doing some performance benchmarks:
```julia
julia> using BenchmarkTools

julia> @benchmark mean(data) setup=(data=rand(Float16, 10^5))
BenchmarkTools.Trial: 8926 samples with 1 evaluation.
 Range (min … max):  392.300 μs … 989.200 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     397.000 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   400.816 μs ±  20.073 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▄█ ▇ ▃
  ██▇█▆█▆▅▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  392 μs          Histogram: frequency by time          475 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark new_mean(data) setup=(data=rand(Float16, 10^5))
BenchmarkTools.Trial: 9033 samples with 1 evaluation.
 Range (min … max):  393.200 μs … 914.700 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     397.800 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   400.267 μs ±  16.346 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▅ █▂ ▂
  █▅▃████▆▅▆▄▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  393 μs          Histogram: frequency by time          445 μs <

 Memory estimate: 2.14 KiB, allocs estimate: 29.

julia> @benchmark mean(data) setup=(data=rand(Float16, 10^6))
BenchmarkTools.Trial: 905 samples with 1 evaluation.
 Range (min … max):  3.942 ms … 7.279 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     4.096 ms             ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.139 ms ± 236.673 μs  ┊ GC (mean ± σ): 0.00% ± 0.00%

  ▄█▅▆▄▆▃▂▂▄▅▃▂▁
  ██████████████▇█▆▇▅▅▅▃▄▄▅▃▄▃▃▃▂▂▃▂▂▂▂▂▁▂▃▁▂▁▂▂▁▁▁▁▁▂▂▁▂▂▁▁▂ ▄
  3.94 ms         Histogram: frequency by time        4.95 ms <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark new_mean(data) setup=(data=rand(Float16, 10^6))
BenchmarkTools.Trial: 923 samples with 1 evaluation.
 Range (min … max):  3.939 ms … 4.956 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     4.027 ms             ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.056 ms ± 110.197 μs  ┊ GC (mean ± σ): 0.00% ± 0.00%

  ▃▄▆▄▇█▆▆▇▄▂▂
  ▃▅████████████▇▇▅▆▄▄▃▄▃▃▃▄▄▄▃▂▃▃▃▂▃▃▂▁▃▃▂▂▂▂▁▂▂▃▂▂▂▂▁▂▂▃▁▂▃ ▄
  3.94 ms         Histogram: frequency by time        4.48 ms <

 Memory estimate: 11.94 KiB, allocs estimate: 29.
```
It increases memory allocation quite a bit, but I think that could be improved by replacing the chunk computation with a loop.
> I'm not sure there's a good rationale for `mean` doing something different
Because the result of a `mean` of a large array can be within the range of a `Float16`, whereas the result of a `sum` is not. It's the same reason why `norm` is not implemented as `sqrt(sum(abs2, array))`: we try to arrange our algorithms to avoid spurious overflow of intermediate results if the final result should be (approximately) representable.

I feel like we already resolved this argument with #25 and this long Discourse thread?
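The `norm` analogy above can be checked directly in the REPL (a sketch; `norm` rescales internally to avoid overflow):

```julia
julia> using LinearAlgebra

julia> x = [floatmax(Float64)];

julia> norm(x)              # rescaled internally: no spurious overflow
1.7976931348623157e308

julia> sqrt(sum(abs2, x))   # naive formula: abs2 overflows first
Inf
```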
Of course, you could also have spurious overflow with `Float32` or `Float64`:

```julia
julia> mean(floatmax(Float32) * [1,1])  # should be 3.4028235f38
Inf32

julia> mean(floatmax(Float64) * [1,1])  # should be 1.7976931348623157e308
Inf
```

but this seems like a less pressing problem, as it's much less likely to happen unexpectedly than for `Float16`.
@stevengj Yeah, we've had a similar discussion before, but then the argument was that we should accumulate using the type that `mean` returns. Here the problem is a bit different. The fact that the sum of a large `Float16` array is unlikely to fit in a `Float16` seems to indicate that `sum` should return a wider type like `Float32` or `Float64`, just like we do for `Int16`. I wish Julia did that, but https://github.com/JuliaLang/julia/issues/20560 discussed only integer types, and now it cannot be changed until Julia 2.0.

Anyway, you have a point that since the mean of `Float16` values can by definition fit in a `Float16`, we can change the implementation to accumulate using a wider type and continue returning a `Float16` without breaking the API, contrary to `sum`. So that's probably the best solution.
@gabrielpreviato Your approach will help a bit, but it won't fix overflow if the sum of the values in a block exceeds `floatmax(Float16)`, which is not unlikely given how small that value is.
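The per-block overflow is easy to trigger: a 10_000-element chunk only needs an average value above about 6.55 (that is, `floatmax(Float16)/10_000`) for the running sum to overflow:

```julia
julia> sum(fill(Float16(60), 10_000))  # a single chunk already overflows
Inf16
```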
I tried implementing accumulation using a wider type, but there's a problem: for heterogeneous collections, it's almost impossible to efficiently compute the type to which the result should be converted back. For example, with `Any[Float16(1.0), Float32(2.0)]`, the fact that the result should be `Float32` cannot be determined easily if we accumulate in `Float64`. It's easy (though costly) to keep track of the type when summing with a loop, but for array inputs we call `sum` to use pairwise summation, and I don't see a way to make it compute that type.

The only solution I can see is to do a separate pass just to compute the type. Maybe that's OK, since it will only affect heterogeneous arrays (which are slow anyway).
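A sketch of that separate pass, assuming the result type can be obtained by promoting the element types:

```julia
A = Any[Float16(1.0), Float32(2.0)]

# First pass: compute the type the result should be converted back to.
T = mapreduce(typeof, promote_type, A)        # Float32 here

# Second pass: accumulate in Float64, then convert back.
m = convert(T, sum(Float64, A) / length(A))   # 1.5f0
```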
Why not implement something that is numerically stable? https://discourse.julialang.org/t/why-are-missing-values-not-ignored-by-default/106756/286
See also https://github.com/JuliaLang/julia/issues/52365. That would indeed fix the overflow with `Float16`, but it would hurt performance significantly. It would be faster and more accurate to accumulate in `Float64`.
Since the current `mean` implementation sums all elements and then divides by the total number of elements, with smaller types (such as `Float16`) it's pretty easy to run into overflow on bigger arrays, as you can see in the following example:

An easy solution when facing this is using `Float32` instead, but I wanted to point out this issue when using `Float16`.