Red-Portal closed this issue 2 months ago.
Can you elaborate on why you want `re(params)` to be type stable? If it's to ensure that subsequent code is type stable, then `re(params)::typeof(m)` in user code might work better. If it's to ensure that `Restructure` is internally type stable, a type assertion won't be enough.
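A minimal sketch of the user-side assertion, assuming a toy `Model` type and a hypothetical `rebuild` function that stands in for `re(params)`. Neither is Optimisers.jl code. `rebuild` routes its slices through an untyped `IdDict`, so the compiler cannot infer its concrete return type:

```julia
# Toy model type (hypothetical), standing in for a Flux-style layer.
struct Model{W,B}
    weight::W
    bias::B
end

# Hypothetical rebuild: the slices pass through an untyped IdDict{Any,Any},
# so inference only knows they are `Any`.
function rebuild(template::Model, flat::AbstractVector)
    cache = IdDict()                                   # IdDict{Any,Any}
    n = length(template.weight)
    cache[:weight] = flat[1:n]
    cache[:bias] = flat[n+1:n+length(template.bias)]
    return Model(cache[:weight], cache[:bias])         # inferred as an abstract `Model`
end

# User-side assertion: `typeof(template)` is a compile-time constant, so code
# that consumes the result of `stable_rebuild` is inferable.
stable_rebuild(template::Model, flat::AbstractVector) =
    rebuild(template, flat)::typeof(template)
```

With `m = Model([1.0, 2.0], [3.0])` and `flat = [4.0, 5.0, 6.0]`, `Base.return_types(rebuild, (typeof(m), typeof(flat)))` reports an abstract `Model`, while the same query for `stable_rebuild` reports the concrete `typeof(m)`. As noted above, this only helps downstream code: `rebuild` itself still makes a dynamically dispatched constructor call.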
Hi @ToucheSir! The first reason is that Enzyme fails because of the instability. The second is that I find the type instability unusual: `Restructure` is fully aware of the type we expect, so I am actually surprised that the current implementation is not type stable out of the box.
> `Restructure` is fully aware of the type we expect

Not quite. You are allowed to pass an array of a different eltype to `Restructure`, which might give you a different model return type.
> I am actually surprised that the current implementation is not type stable out of the box.

But even if you never passed an array of a different eltype, the type instability would remain. This is because Functors uses an untyped `IdDict` internally to keep track of shared/aliased parameters: https://github.com/FluxML/Functors.jl/blob/2eddcb74f9589e61b847362c2a91d91bd90ef628/src/walks.jl#L170-L201
The fix here would be to tell `fmap` not to cache, by passing `cache = nothing`. This isn't done by default because we can't assume that people's models have no parameter sharing/tying, but perhaps the models you're working with don't have any. Disabling the cache would be necessary but may not be sufficient to remove the type instabilities, as the Julia compiler really hates nested recursive calls like the ones Functors.jl uses.
P.S. Since you're working with Enzyme rather than Zygote, you may be interested in adapting the mutating `(de|Re)structure` implementation in https://github.com/FluxML/Optimisers.jl/pull/165 for your use case.
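A rough sketch of why a mutating approach sidesteps the return-type problem, assuming a toy `Model` type. `restructure!` is a hypothetical name, and this is not the API from the linked PR. Enzyme can differentiate through mutation (Zygote cannot), and writing the flat vector into the arrays the model already owns means the function returns its own argument, so its return type is trivially inferable:

```julia
# Toy model type (hypothetical).
struct Model{W,B}
    weight::W
    bias::B
end

# Hypothetical in-place restructure: copies consecutive slices of `flat` into
# the arrays already stored in `model` and returns that same object.
function restructure!(model::Model, flat::AbstractVector)
    nw, nb = length(model.weight), length(model.bias)
    length(flat) == nw + nb ||
        throw(DimensionMismatch("expected $(nw + nb) parameters, got $(length(flat))"))
    copyto!(model.weight, 1, flat, 1, nw)       # weight <- flat[1:nw] (linear order)
    copyto!(model.bias, 1, flat, nw + 1, nb)    # bias   <- flat[nw+1:nw+nb]
    return model
end
```

One design consequence: `copyto!` converts values into the existing element type, so passing a `Float32` vector still leaves `Float64` arrays in the model. The model's type can never change, unlike a rebuild that allocates new arrays from `flat`.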
Thanks for the tips. Let me try explicitly asserting the return type of the reconstruction call, as you suggested.
Hi, it seems that calling `restructure` is not type stable by default. This is currently causing issues with Enzyme.jl (see this issue). Here is an MWE to illustrate the point:

This returns:
where the return type `Model` is not stable in terms of its type parameters. This can be solved in a brute-force manner by defining the call method on `re::Restructure` (defined here) as:

where we are informing the compiler that the return type will be the same as that of `re.model`. I think this is safe to assume, and it immediately resolves the instability. Any thoughts on this?