ITensor / ITensorMPS.jl

MPS and MPO methods based on ITensor (ITensors.jl)
Apache License 2.0

[ITensors] [BUG] logdot and lognorm are not compatible with AD due to try/catch #77

Open ArtemStrashko opened 2 years ago

ArtemStrashko commented 2 years ago

Description of bug

Because `lognorm` and `logdot` use `try`/`catch` internally, they cannot be differentiated with automatic differentiation (AD), e.g. Zygote.

Minimal code demonstrating the bug or unexpected behavior


```julia
using ITensors
using LinearAlgebra
using Zygote  # provides the f'(x) gradient syntax used below

l1 = x -> logdot(x, b)
l2 = x -> lognorm(x, b)
l3 = x -> norm(x, b)

inds = [Index(2) for _ in 1:10]
a = randomMPS(inds)
b = randomMPS(inds)

l1'(a)
l2'(a)
l3'(a)
```

Expected output or behavior

Return a valid gradient.

Actual output or behavior


```
Compiling Tuple{ITensors.var"##_log_or_not_dot#663", Bool, typeof(ITensors._log_or_not_dot), MPS, MPS, Bool}: try/catch is not supported.
```


mtfishman commented 2 years ago

Thanks for the report @ArtemStrashko. I think you mean `l2 = x -> lognorm(x)` in your example?

I was thinking of handling this by defining custom `rrule`s for `logdot`/`lognorm` based on the chain rule for `log(dot(x, y))` and `log(norm(x))`, making use of the fact that `dot(x, y)` and `norm(x)` already have derivatives defined. Concretely, the `rrule` for `logdot` would explicitly use d(logdot(x, y))/dx = inv(dot(x, y)) * d(dot(x, y))/dx, taking d(dot(x, y))/dx from the existing `rrule` for `dot` (and analogously for `lognorm` with `norm`).
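The chain-rule approach described above can be sketched with ChainRulesCore. This is a hedged illustration, not ITensors' implementation: it works on plain `Vector`s instead of `MPS`, the names `logdot_sketch` and `lognorm_sketch` are hypothetical stand-ins whose primal code contains a `try`/`catch` like the functions in this issue, and it assumes Zygote as the AD backend. Each `rrule` evaluates the primal outside of AD, obtains the pullback of `dot` or `norm` via `rrule_via_ad`, and scales the incoming cotangent by the inverse of the overlap or norm before passing it on, so AD never has to look inside the `try`/`catch`.

```julia
using LinearAlgebra
using ChainRulesCore
using Zygote

# Hypothetical stand-in for logdot; without a custom rule, Zygote (at least at
# the time of this issue) fails on the try/catch below.
function logdot_sketch(x::AbstractVector, y::AbstractVector)
    d = dot(x, y)
    try
        return log(d)
    catch err
        # A negative real overlap throws a DomainError; fall back to the complex log.
        err isa DomainError || rethrow()
        return log(complex(d))
    end
end

# Hypothetical stand-in for lognorm with the same try/catch structure.
function lognorm_sketch(x::AbstractVector)
    try
        return log(norm(x))
    catch
        return -Inf  # hypothetical fallback, only here to mirror the structure
    end
end

# d log(dot(x, y)) = inv(dot(x, y)) * d dot(x, y), reusing dot's existing pullback.
function ChainRulesCore.rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(logdot_sketch), x, y
)
    Ω = logdot_sketch(x, y)  # primal, evaluated without differentiating through it
    d, dot_pullback = rrule_via_ad(config, dot, x, y)
    function logdot_sketch_pullback(ΔΩ)
        # log'(d) = inv(d); the ChainRules convention applies its conjugate.
        _, x̄, ȳ = dot_pullback(unthunk(ΔΩ) / conj(d))
        return NoTangent(), x̄, ȳ
    end
    return Ω, logdot_sketch_pullback
end

# d log(norm(x)) = inv(norm(x)) * d norm(x), reusing norm's existing pullback.
function ChainRulesCore.rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(lognorm_sketch), x
)
    Ω = lognorm_sketch(x)
    n, norm_pullback = rrule_via_ad(config, norm, x)
    function lognorm_sketch_pullback(ΔΩ)
        _, x̄ = norm_pullback(unthunk(ΔΩ) / n)  # n is real and nonnegative
        return NoTangent(), x̄
    end
    return Ω, lognorm_sketch_pullback
end

# Usage with a positive overlap, so the primal takes the non-throwing branch.
x = [1.0, 2.0, 3.0]
y = [0.5, 1.0, 2.0]
g_dot = Zygote.gradient(v -> logdot_sketch(v, y), x)[1]   # equals y ./ dot(x, y)
g_norm = Zygote.gradient(lognorm_sketch, x)[1]            # equals x ./ norm(x)^2
```

An actual fix would presumably attach rules like these to `logdot(::MPS, ::MPS)` and `lognorm(::MPS)`. One caveat worth checking: these pullbacks form `dot(x, y)` and `norm(x)` explicitly, which the log versions presumably exist to avoid for large systems where those values may overflow, so this is a starting point rather than a complete solution.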

I won't have time to investigate this right now, but feel free to take a look at it if you are interested.