EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Fix type escaping in `@import_frule`, `@import_rrule` #1446

Closed: mofeing closed this pull request 4 months ago

mofeing commented 4 months ago

While trying to import some ChainRules from a package, I got the following error:

julia> using Enzyme, OMEinsum

julia> Enzyme.@import_rrule(typeof(einsum), OMEinsum.EinCode, Any, Any)
UndefVarError: `OMEinsum` not defined

Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Enzyme/srACB/ext/EnzymeChainRulesCoreExt.jl:181 [inlined]
 [2] top-level scope
   @ ~/Developer/k-local-gradient-descent/notebooks/enzyme-omeinsum-from-chainrules.ipynb:1

By expanding the macro, I saw that the type annotations were not correctly escaped. For example, for `augmented_primal` we get

... augmented_primal(...) where {... var"#378#AN_1" <: Enzyme.Annotation{<:(Enzyme.OMEinsum).EinCode}, var"#379#AN_2" <: Enzyme.Annotation{<:Enzyme.Any}, var"#380#AN_3" <: Enzyme.Annotation{<:Enzyme.Any}}

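Note that `Enzyme.OMEinsum` and `Enzyme.Any` are wrong: the type names passed to the macro are resolved inside the `Enzyme` module instead of in the caller's scope.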
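A minimal sketch of the underlying macro-hygiene behavior, using only Base Julia and assuming nothing about Enzyme's internals. The modules `Host` and `User` are hypothetical stand-ins for `Enzyme` and `OMEinsum`. A macro that splices a caller-supplied type expression into its output without `esc` has the free names in that expression resolved in the macro's own module, which matches the `Enzyme.OMEinsum` symptom above; wrapping the expression in `esc` resolves it at the call site instead.

```julia
# Hypothetical stand-in for the macro-defining package (Enzyme in the report).
module Host
# No `esc`: hygiene resolves free names in `T` inside `Host`.
macro unescaped(T)
    return T
end

# With `esc`: names in `T` are resolved in the caller's module.
macro escaped(T)
    return esc(T)
end
end # module Host

# Hypothetical stand-in for the user's package (OMEinsum in the report).
module User
struct Widget end
end # module User

# Escaped: `User.Widget` is looked up in Main, where `User` exists.
escaped_T = Host.@escaped(User.Widget)

# Unescaped: the expansion refers to `Host.User.Widget`, but `Host` has no `User`.
unescaped_err = try
    Host.@unescaped(User.Widget)
    nothing
catch err
    err
end

# Inspect the hygienic expansion of the unescaped macro call.
println(@macroexpand Host.@unescaped(User.Widget))
println(escaped_T)
println(typeof(unescaped_err))
```

The `Any` annotations in the report happen to work, because `Any` is visible inside `Enzyme` too. A module that only the caller has loaded, like `OMEinsum`, fails with `UndefVarError`.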

I haven't tested it yet, but this PR should fix it.

wsmoses commented 4 months ago

If possible, can you add a test?

(Sent by email in reply to Sergio Sánchez Ramírez's original message of Tue, May 14, 2024 at 9:03 AM.)

mofeing commented 4 months ago

Added! It tests against a MockType in a MockModule.

wsmoses commented 4 months ago

@mofeing this fails CI

mofeing commented 4 months ago

I forgot to add methods for `fdiff` and `rdiff`. It should be fixed now.

mofeing commented 4 months ago

I fixed all errors in the tests except one: the `rrule` test for `MockType`.

The pullback in the `rrule` defined for `mock_function` should return a number:

function ChainRulesCore.rrule(::typeof(MockModule.mock_function), x)
    y = MockModule.mock_function(x)
    return y, ȳ -> 2 * ȳ
end

But Enzyme seems to return a `MockType` instead. Is this because we annotate the return activity as `Active`?

rdiff(f, x::MockModule.MockType) = autodiff(Reverse, f, Active, Active(x))[1][1]

Enzyme.@import_rrule typeof(MockModule.mock_function) MockModule.MockType
@test rdiff(MockModule.mock_function, MockModule.MockType(1f0)) === 2f0

...

import_rrule: Test Failed at /Users/mofeing/Developer/Enzyme.jl/test/ext/chainrulescore.jl:117
  Expression: rdiff(MockModule.mock_function, MockModule.MockType(1.0f0)) === 2.0f0
   Evaluated: Main.MockModule.MockType(2.0f0) === 2.0f0

wsmoses commented 4 months ago

@mofeing CI still fails:

import_rrule: Test Failed at /home/runner/work/Enzyme.jl/Enzyme.jl/test/ext/chainrulescore.jl:117
  Expression: rdiff(MockModule.mock_function, MockModule.MockType(1.0f0)) === 2.0f0
   Evaluated: Main.MockModule.MockType(2.0f0) === 2.0f0

wsmoses commented 4 months ago

And no: for `Active` arguments, Enzyme returns the derivative as a value of the same type as the argument (whatever that type is), so here the result is a `MockType`.
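To illustrate this rule with plain floating-point numbers, here is a small sketch. It assumes a recent Enzyme.jl release where `autodiff(Reverse, ...)` returns a tuple whose first element holds the derivatives of the `Active` arguments, as the `[1][1]` indexing in the `rdiff` helper above implies. `square` is a hypothetical helper. The derivative comes back as `Float32` for a `Float32` argument and as `Float64` for a `Float64` argument. The same rule explains why the `MockType` test evaluates to `MockType(2.0f0)` rather than `2.0f0`.

```julia
using Enzyme

# Same call pattern as the thread's `rdiff` helper:
# reverse mode, Active return, one Active argument, take its derivative.
rdiff(f, x) = autodiff(Reverse, f, Active, Active(x))[1][1]

# Hypothetical test function; d/dx (x * x) = 2x.
square(x) = x * x

dx32 = rdiff(square, 3.0f0)  # Float32 argument
dx64 = rdiff(square, 3.0)    # Float64 argument

# The derivative has the same type as the Active argument.
println(typeof(dx32))
println(typeof(dx64))
```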

mofeing commented 4 months ago

Okay, fixed now.

codecov-commenter commented 4 months ago

Codecov Report

All modified and coverable lines are covered by tests.

Project coverage is 72.43%. Comparing base (cc8ceb6) to head (8130415). Report is 2 commits behind head on main.


Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #1446      +/-   ##
==========================================
+ Coverage   68.04%   72.43%   +4.39%
==========================================
  Files          30       30
  Lines       10772    10838      +66
==========================================
+ Hits         7330     7851     +521
+ Misses       3442     2987     -455
```
