FluxML / FluxML-Community-Call-Minutes

The FluxML Community Team repo

Constructor for `collect(iterable layers)` #36

Closed darsnack closed 3 years ago

darsnack commented 3 years ago

I came across this use case when working on FluxPrune.jl. If I have an array of functions that operate on layers, then I can do `map((f, m) -> f(m), fs, model)` when `model` is iterable like `Chain` or `Parallel`. But this returns a vector of layers instead of the type of `model`. For `Chain`, I can just pass the output of `map` back into the constructor, but this isn't the case for `Parallel`.

Of course, I can special case `fs::AbstractVector` and `model::Chain` in my case, but it would be nice to have a more generic way to do this. Basically something like `re` from Functors.jl but for the output of `iterate(model)`.
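
To make the situation concrete, here is a minimal sketch of the `Chain` case described above. It assumes a recent Flux release in which `Chain` forwards `Base.iterate` to its layers; the model and the functions in `fs` are illustrative choices and are not taken from FluxPrune.jl.

```julia
using Flux

# Illustrative model and per-layer functions (assumptions for this sketch)
model = Chain(Dense(4 => 8, relu), Dense(8 => 2))
fs = [identity, f64]  # leave layer 1 alone, convert layer 2 to Float64

# Iterating a Chain yields its layers, so `map` returns a plain Vector
out = map((f, m) -> f(m), fs, model)

# For Chain, splatting the result into the constructor restores the type
rebuilt = Chain(out...)
```

The last step only works because the `Chain` constructor takes nothing but its layers. A container that stores anything besides the iterated elements cannot be rebuilt from the mapped vector alone.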

ToucheSir commented 3 years ago

If I'm understanding the ask correctly, this should be doable with functor:

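```julia
using Functors: functor

# Split the model into its children and a reconstruction closure
layers, re = functor(model)
transformed_layers = map((f, m) -> f(m), fs, layers)
model = re(transformed_layers)
```

darsnack commented 3 years ago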

That's what I tried at first too, but in general the children returned by `functor` are not the same as the elements you get by iterating the model.
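
A small sketch of that mismatch, using `Parallel` as the example. It assumes Flux and Functors.jl are both installed and that `Parallel` stores its fields as `connection` and `layers`, as in the Flux source linked in this thread; the concrete layers are made up for illustration.

```julia
using Flux
using Functors: functor

# Illustrative model: two branches combined with `+`
model = Parallel(+, Dense(4 => 2), Dense(4 => 2))

# The children are the struct's fields, not a flat list of branch layers
children, re = functor(model)

child_names = keys(children)        # field names of Parallel
branches = children.layers          # the branch layers live one level down
combiner = children.connection      # the connection is a child as well
```

So a function list meant for the branches cannot be zipped directly against the children: it has to target `children.layers` and leave `children.connection` to be handled separately.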

ToucheSir commented 3 years ago

That's certainly true, but I'm coming up blank for examples in the wild where it applies. For example, the only iterable layer in https://github.com/FluxML/Flux.jl/blob/master/src/layers/basic.jl is `Chain`. `Parallel` isn't iterable, and I do wonder whether you'd even want to `@forward Parallel.layers Base.iterate`, because `Parallel.connection` may itself be a trainable layer.

darsnack commented 3 years ago

> `Parallel.connection` may itself be a trainable layer.

True, so maybe iterating layers is the wrong abstraction. Maybe the right one is a collection of functions with the same structure as the model itself. The use case I'm describing is: how do you selectively apply functions to parts of a model in a generic way?

There are two parts, I guess. First, a convenient way to write this structure of functions. Second, basically a generic version of `update`/`apply` from Optimisers.jl, where you recurse through both structures together and call each function on its corresponding node.
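
As a rough sketch of the second part, the recursion could look something like the following. `apply_tree` is a hypothetical helper, not an existing Flux, Functors.jl, or Optimisers.jl function, and the "model" here is a stand-in built from plain nested tuples and NamedTuples rather than real layers, so the example runs with Base Julia only.

```julia
# Hypothetical helper: walk a tree of functions `fs` and a tree `node`
# with the same nesting, calling each function on the matching node.
apply_tree(f::Function, node) = f(node)
apply_tree(fs::Tuple, node::Tuple) = map(apply_tree, fs, node)
apply_tree(fs::NamedTuple{K}, node::NamedTuple{K}) where {K} =
    NamedTuple{K}(map(apply_tree, values(fs), values(node)))

# Stand-in for a Parallel-like model: a connection plus two branches
model = (connection = +, layers = ([1.0, 2.0], [3.0, 4.0]))

# Function tree with the same shape; `identity` means "leave untouched"
fs = (connection = identity, layers = (x -> 2 .* x, identity))

new_model = apply_tree(fs, model)
```

A function placed at an inner position receives the whole subtree at that position, which is what allows selective application at any depth. For real layers, one would presumably obtain the children and the reconstruction closure from `functor` at each level instead of matching on tuples directly.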