brianguenter / FastDifferentiation.jl

Fast derivative evaluation
MIT License

TypeError: non-boolean (FastDifferentiation.Node) used in boolean context #91

Closed RJDennis closed 1 month ago

RJDennis commented 2 months ago

The following used to work but now throws an error. The problem seems to arise when a variable is raised to a power that is itself a variable.

using FastDifferentiation

@variables x y

f = x^y

derivative([f],x)

brianguenter commented 2 months ago

I expected some problems with this latest release, but that was not one of them. Check out the release notes for 0.4.0.

Short answer: I understand the problem. It won't be fixed immediately, for reasons explained below. Your best bet for now is to revert to the previous version, 0.3.17. I will try to get a patch release out that reverts to the 0.3.17 behavior, since nobody should be using 0.4.0.

Long answer: I am in the process of adding conditionals to FastDifferentiation. This is major surgery so I broke it into two parts: adding conditional expressions and correctly differentiating through them. I thought adding conditionals would not cause any FastDifferentiation code to break but might cause problems in user code. Alas, it did break FastDifferentiation.

This PR https://github.com/brianguenter/FastDifferentiation.jl/pull/90#issue-2487795686 added conditional expressions. You can now do this:


julia> @variables x y
y

julia> f = ifelse(x<y,x,y)
(ifelse  (x < y) x y)

julia> g = make_function([f],[x,y])

julia> g(1,2)
1-element Vector{Float64}:
 1.0

julia> g(2,1)
1-element Vector{Float64}:
 2.0

But you cannot differentiate through functions that have conditionals. I forgot that some of the differentiation rules in DiffRules.jl use conditionals on the variables. For example, this is the diff rule for ^:

@define_diffrule Base.:^(x, y) = :( $y * ($x^($y - 1)) ), :( ($x isa Real && $x<=0) ? Base.oftype(float($x), NaN) : ($x^$y)*log($x) )

There is no way to know whether x <= 0 at the time you create your expression. This has to be determined at run time when x is assigned a value.
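As a point of comparison, here is a minimal sketch of the same rule written as ordinary Julia functions on plain numbers. The names `dpow_dx`, `dpow_dy` and `fd` are made up for this illustration and are not part of DiffRules or FastDifferentiation. With numeric arguments, `x <= 0` is an ordinary Bool, so the ternary works, and the results can be checked against a central finite difference.

```julia
# Partial derivatives of x^y for plain numbers (hypothetical helper names).
dpow_dx(x, y) = y * x^(y - 1)
dpow_dy(x, y) = (x isa Real && x <= 0) ? oftype(float(x), NaN) : x^y * log(x)

# Central finite difference, used only as a sanity check.
fd(f, t; h=1e-6) = (f(t + h) - f(t - h)) / (2h)

x0, y0 = 1.5, 2.0

# Both partials agree with the finite-difference estimates.
@assert isapprox(dpow_dx(x0, y0), fd(t -> t^y0, x0); rtol=1e-6)
@assert isapprox(dpow_dy(x0, y0), fd(t -> x0^t, y0); rtol=1e-6)

# For x <= 0 the rule returns NaN without ever calling log(x).
@assert isnan(dpow_dy(0.0, 2.0))
@assert isnan(dpow_dy(-1.0, 2.0))
```

The ternary only evaluates the branch it selects, which is why `log(x)` is never reached for non-positive `x` here. That only works because the condition is a concrete Bool when the function runs.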

Before this PR some conditionals on expressions were defined in an ad hoc way that made the differentiation work most of the time but caused it to fail in some cases that would seem mysterious to the end user.

Now conditionals on expressions return expressions, not Bools. This causes the DiffRules rule to crash when it evaluates $x<=0: the ternary expects a Bool and gets an expression.

I will fix the rules in DiffRules that used a ? b : c to instead use ifelse(a,b,c) so they correctly return an expression instead of crashing. But you still won't be able to properly compute a derivative through this expression so this won't solve your problem.
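The difference between the two forms can be sketched without FastDifferentiation at all. In the toy example below, `Sym` and `sym_ifelse` are hypothetical stand-ins invented for this illustration; they are not the package's `Node` type or its conditional function. The ternary operator is syntax that requires a Bool, while a function call can dispatch on a symbolic condition and build a new expression.

```julia
# Hypothetical stand-in for a symbolic expression node (not FastDifferentiation.Node).
struct Sym
    repr::String
end

# Comparing a symbolic value with a number yields another expression, not a Bool.
Base.:<=(a::Sym, b::Real) = Sym("($(a.repr) <= $b)")

# A function-style conditional (hypothetical name) with one method per condition type.
sym_ifelse(c::Bool, a, b) = ifelse(c, a, b)
sym_ifelse(c::Sym, a, b) = Sym("(ifelse $(c.repr) $a $b)")

x = Sym("x")
cond = x <= 0

# The ternary operator requires a Bool condition, so this throws a TypeError.
ternary_failed = try
    cond ? NaN : 1.0
    false
catch err
    err isa TypeError
end

# The function form dispatches on the condition's type and returns an expression.
expr = sym_ifelse(cond, NaN, 1.0)

println(ternary_failed)
println(expr.repr)
```

Unlike the ternary, a function-style conditional evaluates both branch arguments before the call, so the symbolic version records both branches and leaves the choice to run time.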

brianguenter commented 1 month ago

This PR https://github.com/brianguenter/FastDifferentiation.jl/pull/93#issue-2532311530 has a patch to fix this problem. Now this works:

julia> @variables x y
y

julia> f = x^y
(x ^ y)

julia> derivative([f],x)
1-element Vector{FastDifferentiation.Node}:
 (y * (x ^ (y - 1)))

julia> derivative([f],y)
1-element Vector{FastDifferentiation.Node}:
 (if_else  (x <= 0) NaN ((x ^ y) * log(x)))

julia> g = derivative([f],y)
1-element Vector{FastDifferentiation.Node}:
 (if_else  (x <= 0) NaN ((x ^ y) * log(x)))

julia> h = make_function(g,[x,y])
...
julia> h(0.0,2.0)
1-element Vector{Float64}:
 NaN

julia> h(1.1,2.0)
1-element Vector{Float64}:
 0.10584521819377114

You still cannot compute the derivative of an expression that contains a conditional. That is coming in a future PR.

RJDennis commented 1 month ago

Excellent. I refer to and illustrate the use of your package in my lecture notes for a course I teach to students in Economics. Thank you for your work. I look forward to the tagged release.

brianguenter commented 1 month ago

I just submitted 0.4.1, which should fix this bug, to the Julia registry. The new version should be up on the repo in 15-20 minutes.

Try it out - if it fixes the bug for you I'll close this issue.

RJDennis commented 1 month ago

Yes, that fixes it. Thanks.