Closed darsnack closed 3 years ago
If I'm understanding the ask correctly, this should be doable with `functor`:

```julia
layers, re = functor(model)
# fs: a matching collection of transformation functions
transformed_layers = map((f, m) -> f(m), fs, layers)
model = re(transformed_layers)
```
That's what I went for at first too, but the children from `functor` are not the same as the elements returned by `iterate` in general.
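For example (a toy sketch with hypothetical types, not Flux's actual definitions): a `Parallel`-like layer's structural children include its `connection`, while iteration might only visit the wrapped layers.

```julia
# Toy sketch (hypothetical types, NOT Flux's) showing how structural
# children can differ from what iteration yields.
struct MyParallel{F,T}
    connection::F
    layers::T
end

# functor-style children: all fields, including the connection
children(p::MyParallel) = (connection = p.connection, layers = p.layers)

# a hypothetical iterate that only walks the wrapped layers
Base.iterate(p::MyParallel, state...) = iterate(p.layers, state...)
Base.length(p::MyParallel) = length(p.layers)

p = MyParallel(+, (sin, cos))
children(p)  # includes the connection
collect(p)   # only the layers
```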
That's certainly true, but I'm coming up blank for examples in the wild where it applies. For example, the only iterable layer in https://github.com/FluxML/Flux.jl/blob/master/src/layers/basic.jl is `Chain`. `Parallel` isn't, and I do wonder if you'd want to `@forward Parallel.layers Base.iterate`, because `Parallel.connection` may itself be a trainable layer.
> `Parallel.connection` may itself be a trainable layer.
True, so maybe iterating layers is the wrong abstraction. Maybe the right one is a collection of functions with the same structure as the model itself. The use case I'm describing is: how do you selectively apply functions to parts of a model in a generic way?
There are two parts, I guess. First, a convenient way to write this structure of functions. Second, basically a generic version of `update`/`apply` from Optimisers.jl, where you recurse both structures together and call each function on each node.
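A minimal sketch of that recursion (the `applytree` helper here is hypothetical; Optimisers.jl's real `update`/`apply` do more, e.g. carry optimiser state):

```julia
# Walk a tree of functions and a model with the same structure,
# applying each leaf function to the corresponding leaf of the model.
applytree(f::Function, m) = f(m)
applytree(fs::Tuple, ms::Tuple) = map(applytree, fs, ms)
applytree(fs::NamedTuple, ms::NamedTuple) = map(applytree, fs, ms)

fs = ((x -> 2x, x -> x + 1), x -> -x)
m  = ((1.0, 2.0), 3.0)
applytree(fs, m)  # ((2.0, 3.0), -3.0)
```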
I came across this use case when working on FluxPrune.jl. If I have an array of functions that operate on layers, then I can do

```julia
map((f, m) -> f(m), fs, model)
```

when `model` is iterable like `Chain` or `Parallel`. But this returns a vector of layers instead of the type of `model`. For `Chain`, I can just pass the output of `map` back into the constructor, but this isn't the case for `Parallel`.

Of course, I can special-case `fs::AbstractVector` and `model::Chain` in my case, but it would be nice to have a more generic way to do this. Basically something like `re` from Functors.jl but for the output of `iterate(model)`.
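For illustration, the `Chain` constructor round-trip could look like this (toy `MyChain` standing in for Flux's `Chain`; all names here are hypothetical):

```julia
# Toy Chain-like type: iterable over its layers, rebuildable from varargs.
struct MyChain{T<:Tuple}
    layers::T
end
MyChain(layers...) = MyChain(layers)
Base.iterate(c::MyChain, state...) = iterate(c.layers, state...)
Base.length(c::MyChain) = length(c.layers)

model = MyChain(sin, cos)
fs = [m -> m, m -> (m === cos ? tan : m)]    # per-layer transformations
new_layers = map((f, m) -> f(m), fs, model)  # a Vector, not a MyChain
model = MyChain(new_layers...)               # rebuild via the constructor
```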