xgdgsc opened this issue 4 years ago.
FWIW, I find that in this case something like `sum(x -> !isnan(x) * x, arr, dims=1)` (i.e., without using NaNMath) works just as well. (And maybe it is a bit faster than a NaNMath implementation, since it relies on the built-in `sum`?)
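The trick works because `false` acts as a strong zero in Julia: `false * NaN` is `0.0` (whereas `0.0 * NaN` is `NaN`), so every NaN entry contributes nothing to the built-in `sum`. A minimal sketch on a small made-up matrix:

```julia
# `false` is a "strong zero": false * NaN == 0.0, so multiplying by
# !isnan(x) zeroes out NaN entries before the built-in sum sees them.
arr = [1.0 2.0; NaN 4.0; 3.0 NaN]

colsums = sum(x -> !isnan(x) * x, arr, dims=1)   # 1×2 matrix: [4.0 6.0]
rowsums = sum(x -> !isnan(x) * x, arr, dims=2)   # 3×1 matrix: [3.0; 4.0; 3.0]
```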
Nice solution for `sum`, but unfortunately it doesn't work for, e.g., `mean`: the NaN entries are zeroed out but still counted in the denominator.
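For `mean` specifically, one workaround is to divide the NaN-skipping sum by the number of non-NaN entries, both computed with Base reductions along `dims`. This is a sketch, not NaNMath's implementation; `nanmean_dims` is a hypothetical name, and the approach does not extend to `var` or other non-linear statistics.

```julia
# Hypothetical helper (not part of NaNMath): NaN-skipping mean along `dims`,
# built only from Base reductions, so no per-slice copies are made.
function nanmean_dims(A; dims)
    s = sum(x -> !isnan(x) * x, A; dims=dims)  # NaNs contribute 0 to the sum
    n = sum(!isnan, A; dims=dims)              # count of non-NaN entries
    return s ./ n                              # an all-NaN slice gives 0/0 = NaN
end

A = [1 2 3; 4 5 6; 7 8 9; NaN 11 12]
nanmean_dims(A, dims=1)   # [4.0 6.5 7.5]
nanmean_dims(A, dims=2)   # [2.0; 5.0; 8.0; 11.5]
```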
Riffing on Julia Base and daneel's code:
```julia
using Statistics, Test

# Apply f to all non-NaN entries of A (dims=: means the whole array).
_nanfunc(f, A, ::Colon) = f(filter(!isnan, A))
# Apply f to each slice along `dims`, skipping NaNs within each slice.
_nanfunc(f, A, dims) = mapslices(a -> _nanfunc(f, a, :), A, dims=dims)
# Public entry point, mirroring the `dims` keyword of Base/Statistics reductions.
nanfunc(f, A; dims=:) = _nanfunc(f, A, dims)

A = [1 2 3; 4 5 6; 7 8 9; NaN 11 12]
@test isapprox(nanfunc(mean, A), mean(filter(!isnan, A)))
@test nanfunc(mean, A, dims=1) == [4.0 6.5 7.5]
@test nanfunc(mean, A, dims=2) == transpose([2.0 5.0 8.0 11.5])
@test isapprox(nanfunc(var, A), var(filter(!isnan, A)))
@test nanfunc(var, A, dims=1) == [9.0 15.0 15.0]
@test nanfunc(var, A, dims=2) == transpose([1.0 1.0 1.0 0.5])
```
Can we actually make this a PR? One issue I see is that `mapslices` doesn't play nicely with `@view`, so at the moment, if you actually use `dims`, it slows down significantly and allocates a lot:
```julia
julia> a = rand([NaN, 1,2,3,4,5], 100,100,100);

julia> @btime nanfunc(mean, a);
  1.188 ms (4 allocations: 7.63 MiB)

julia> @btime NaNMath.mean(a);
  2.035 ms (1 allocation: 16 bytes)

julia> @btime nanfunc(mean, a; dims=2);
  10.382 ms (120039 allocations: 11.37 MiB)
```
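One possible mitigation, offered as an unbenchmarked sketch: hand `f` a lazy `Iterators.filter` over each slice instead of a `filter`ed copy, so no temporary vector is materialized per slice. `skipnan` and `lazy_nanfunc` are hypothetical names; this only works for functions that accept a general iterator (Statistics' `mean` and `var` do), and `mapslices` still has its own per-slice overhead, so it is not expected to remove all of the allocations shown above.

```julia
using Statistics

# Hypothetical helper: lazily skip NaN entries without allocating a copy.
skipnan(itr) = Iterators.filter(!isnan, itr)

# Hypothetical variant of nanfunc: `f` receives a lazy iterator over the
# non-NaN entries rather than a freshly allocated filtered vector.
lazy_nanfunc(f, A; dims=:) =
    dims isa Colon ? f(skipnan(A)) : mapslices(a -> f(skipnan(a)), A; dims=dims)

A = [1 2 3; 4 5 6; 7 8 9; NaN 11 12]
lazy_nanfunc(mean, A)            # mean of the 11 non-NaN entries
lazy_nanfunc(mean, A, dims=1)    # [4.0 6.5 7.5]
lazy_nanfunc(var, A, dims=2)     # [1.0; 1.0; 1.0; 0.5]
```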
Hi @bjarthur! I have a question about how to pass extra arguments to the function used inside `nanfunc()`. For example, what if we wanted to compute `std()`, which may or may not be bias-corrected, or any other function that needs one or more extra arguments?
Thanks!
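One way to handle extra arguments without changing `nanfunc` is to wrap them in an anonymous function; alternatively, a wrapper can forward keyword arguments to `f`. A minimal sketch, reusing the `nanfunc` definitions from earlier in the thread; `nanfunc_kw` is a hypothetical name, not an existing API.

```julia
using Statistics

# nanfunc as defined earlier in the thread
_nanfunc(f, A, ::Colon) = f(filter(!isnan, A))
_nanfunc(f, A, dims) = mapslices(a -> _nanfunc(f, a, :), A, dims=dims)
nanfunc(f, A; dims=:) = _nanfunc(f, A, dims)

A = [1 2 3; 4 5 6; 7 8 9; NaN 11 12]

# Option 1: bake the extra arguments into an anonymous function
# (works the same way for positional arguments, e.g. a -> quantile(a, 0.9)).
nanfunc(a -> std(a; corrected=false), A, dims=1)

# Option 2 (hypothetical wrapper): forward any extra keyword arguments to `f`.
nanfunc_kw(f, A; dims=:, kwargs...) = nanfunc(a -> f(a; kwargs...), A; dims=dims)
nanfunc_kw(std, A; dims=1, corrected=false)
```

The closure approach needs no changes to `nanfunc` at all; the forwarding wrapper only saves typing when the same keyword is used repeatedly.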
Support `sum(arr, dims=1)`, like the standard Julia `sum`, to apply the reduction along a given axis.