Closed MilesCranmer closed 2 years ago
I can take a look at this. I'm not too familiar with the whole code base yet, but right now it seems the task is more or less to trace the use of the operations defined in src/Operators.jl and replace the occurrences with NaNMath.log and so on. If you have any other thoughts on the approach before I start, let me know.
I also wanted to ask: I just did a quick, naive benchmark of NaNMath versus my own implementation, and NaNMath seems to perform consistently a bit worse:
julia> using NaNMath
julia> function log_nan(x)
           x < 0.0 && return NaN
           return log(x)
       end
log_nan (generic function with 1 method)
julia> X = rand(50_000_000) .* 2 .- 1;
julia> @time NaNMath.log.(X);
0.661889 seconds (132.60 k allocations: 388.716 MiB, 0.80% gc time, 6.57% compilation time)
julia> @time NaNMath.log.(X);
0.685413 seconds (4 allocations: 381.470 MiB)
julia> @time NaNMath.log.(X);
0.614610 seconds (4 allocations: 381.470 MiB)
julia> @time log_nan.(X);
0.644774 seconds (117.23 k allocations: 387.860 MiB, 12.58% gc time, 11.42% compilation time)
julia> @time log_nan.(X);
0.543851 seconds (4 allocations: 381.470 MiB)
julia> @time log_nan.(X);
0.477451 seconds (4 allocations: 381.470 MiB, 3.61% gc time)
Comparing the two best times above (0.614610 s versus 0.477451 s), NaNMath is about 29% slower. Maybe this is not significant enough (or maybe I'm too naive in trusting this benchmark), and maybe it is easier to just go with NaNMath anyway. I just wanted to ask if that is still the desired solution.
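For reference, the 29% figure can be reproduced from the fastest run of each variant. This is only a small arithmetic sketch; the variable names are made up for illustration.

```julia
# Best of the three timed runs for each variant, copied from the session above.
t_nanmath = 0.614610  # seconds, NaNMath.log.(X)
t_custom  = 0.477451  # seconds, log_nan.(X)

# Relative slowdown of NaNMath with respect to the hand-written version.
slowdown_percent = (t_nanmath / t_custom - 1) * 100
println(round(slowdown_percent; digits=1))
```

Since these numbers come from single @time runs, the ratio should be read as a rough indication rather than a precise measurement.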
Thanks, I appreciate the help, @johanbluecreek!
Very curious result. It might be worth raising an issue on the NaNMath.jl repository to let them know about this? It's probably easier to use in-house versions for common operators though, since it definitely looks like NaNMath slows things down.
Also, a couple points:
- It would be better to return a NaN of the same type as the input, i.e.,
function log_nan(x::T)::T where {T<:Real}
    x <= T(0) && return T(NaN)
    return log(x)
end
- For benchmarking, use BenchmarkTools.@btime instead of @time; it gives more reliable estimates.
Thanks again!
Miles
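To illustrate the point about NaN of the same type, here is a minimal sketch comparing the type-parameterized log_nan with a non-generic variant. The name log_nan_f64 is hypothetical and only stands in for the earlier untyped version.

```julia
# log_nan as suggested above: the NaN it returns has the same type as the input.
function log_nan(x::T)::T where {T<:Real}
    x <= T(0) && return T(NaN)
    return log(x)
end

# Hypothetical non-generic variant for comparison: the literal NaN is a Float64.
function log_nan_f64(x)
    x < 0.0 && return NaN
    return log(x)
end

x32 = -1.0f0
println(typeof(log_nan(x32)))      # Float32
println(typeof(log_nan_f64(x32)))  # Float64
println(typeof(log_nan(2.0f0)))    # Float32
```

With the untyped version, a Float32 input can come back as either Float32 or Float64 depending on the branch taken, which makes the function type-unstable. Note that the typed version as written assumes floating-point inputs: for an integer type T, T(NaN) throws an InexactError.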
Also, after defining new operators with custom names like log_nan, you could map their function name to simply log when printed, by modifying https://github.com/MilesCranmer/SymbolicRegression.jl/blob/29f6bf19d0920f74fb79431418d3e34ea2c4af75/src/Equation.jl#L204
Also, if the user passes log as one of the operators, we want to map it to log_nan. That would use these functions here:
https://github.com/MilesCranmer/SymbolicRegression.jl/blob/29f6bf19d0920f74fb79431418d3e34ea2c4af75/src/Options.jl#L86-L120
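The two mappings described above (log to log_nan on input, log_nan back to "log" when printing) could be sketched roughly as follows. This is not the code in Options.jl or Equation.jl; SAFE_VERSIONS, to_safe, PRINT_NAMES, and print_name are hypothetical names used only for illustration.

```julia
# NaN-returning replacement for log (same definition as earlier in this thread).
function log_nan(x::T)::T where {T<:Real}
    x <= T(0) && return T(NaN)
    return log(x)
end

# Hypothetical lookup table: user-facing operator => internal safe version.
const SAFE_VERSIONS = Dict{Function,Function}(log => log_nan)

# Hypothetical helper: swap in the safe version if one is registered.
to_safe(op::Function) = get(SAFE_VERSIONS, op, op)

# Hypothetical reverse mapping, used only when printing an equation.
const PRINT_NAMES = Dict{Function,String}(log_nan => "log")
print_name(op::Function) = get(PRINT_NAMES, op, string(nameof(op)))

user_ops = [log, cos]
internal_ops = map(to_safe, user_ops)
println(internal_ops == [log_nan, cos])  # true
println(print_name.(internal_ops))       # ["log", "cos"]
```

The user keeps writing and reading log, while the search internally uses the NaN-returning version.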
This was completed with #123. (Thanks again!)
Operators should by default use NaNMath (https://github.com/mlubin/NaNMath.jl) instead of my versions, which simply map input to the valid domain, like log(abs(x)). Since the evaluator already detects NaNs by default, this should work well, and expressions will automatically try to avoid having improper inputs.
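The difference between the two approaches can be sketched as below. The is_valid helper is a hypothetical stand-in for the evaluator's NaN detection, not the actual SymbolicRegression.jl code, and log_abs is just a name given here to the log(abs(x)) version.

```julia
# Domain-mapping version: every nonzero real input is folded into the valid domain.
log_abs(x) = log(abs(x))

# NaN-returning version: inputs outside the domain are flagged instead.
function log_nan(x::T)::T where {T<:Real}
    x <= T(0) && return T(NaN)
    return log(x)
end

# Hypothetical stand-in for the evaluator's NaN check: an expression is
# rejected if any of its outputs is NaN.
is_valid(f, X) = !any(isnan, f.(X))

X = [-2.0, -0.5, 1.0, 3.0]
println(is_valid(log_abs, X))           # true: negative inputs are silently folded
println(is_valid(log_nan, X))           # false: the expression is flagged as invalid
println(is_valid(log_nan, [1.0, 3.0]))  # true
```

With the NaN-returning version, an expression that applies log to negative data is rejected by the check, so the search is pushed toward expressions whose inputs stay inside the operator's domain.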