FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

`count_params` function? #2043

Open ericphanson opened 2 years ago

ericphanson commented 2 years ago

The `show` methods are super nice, especially the counts of parameters and arrays. Could the code that computes those counts be pulled out as an API function? Looking at the source, I guess that would be something like:

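using Functors, Flux
using Functors: isleaf

# Apply `f` to every numeric array found anywhere in the model tree and sum the results.
_childarray_sum(f, x::AbstractArray{<:Number}) = f(x)
_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x); init=0)

function count_params(m)
    ps = params(m)                      # trainable arrays only
    pars = sum(length, ps; init=0)      # init=0 so a model with no parameters returns 0 instead of erroring
    # Everything reachable in the tree minus the trainable part is non-trainable.
    noncnt = _childarray_sum(_->1, m) - length(ps)
    nonparam = _childarray_sum(length, m) - pars
    return (; trainable_arrays=length(ps), trainable_params=pars, non_trainable_arrays=noncnt, non_trainable_params=nonparam)
end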
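
To make the proposal concrete, here is a minimal usage sketch. It assumes a Flux version that still provides the implicit-parameter API `Flux.params`, and it repeats the proposed function (with `init=0` added to the sums) only so the block runs on its own. The example model is hypothetical and not from the Flux source; `BatchNorm` is included because its running mean and variance are arrays that are not trainable.

```julia
using Flux, Functors
using Functors: isleaf

# Proposed helper: apply `f` to every numeric array in the model tree and sum.
_childarray_sum(f, x::AbstractArray{<:Number}) = f(x)
_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x); init=0)

function count_params(m)
    ps = Flux.params(m)                 # trainable arrays (assumes Flux.params is available)
    pars = sum(length, ps; init=0)
    noncnt = _childarray_sum(_ -> 1, m) - length(ps)
    nonparam = _childarray_sum(length, m) - pars
    return (; trainable_arrays=length(ps), trainable_params=pars,
              non_trainable_arrays=noncnt, non_trainable_params=nonparam)
end

# Hypothetical example model: two Dense layers around a BatchNorm.
m = Chain(Dense(10 => 5, relu), BatchNorm(5), Dense(5 => 2))

counts = count_params(m)
@show counts
```

Counting by hand from the layer sizes, `Dense(10 => 5)` holds 50 + 5 values, `BatchNorm(5)` holds 5 + 5 trainable values plus 5 + 5 running statistics, and `Dense(5 => 2)` holds 10 + 2 values, so the result should report 77 trainable parameters in 6 arrays and 10 non-trainable parameters in 2 arrays.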

DhairyaLGandhi commented 2 years ago

Sounds like a good idea.