FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Feature request: Modifying Dense Layer to accommodate kernel/bias constraints and kernel/bias regularisation #1389

Open yewalenikhil65 opened 3 years ago

yewalenikhil65 commented 3 years ago

Hi, I feel the Dense layer should be better equipped with arguments such as kernel/weight constraints, bias constraints, kernel/weight regularisation, and bias regularisation, as is available in TensorFlow. Weight and bias constraints could help avoid overfitting.

tf.keras.layers.Dense(
    units, activation=None, use_bias=True, kernel_initializer='glorot_uniform',
    bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None,
    activity_regularizer=None, kernel_constraint=None, bias_constraint=None,
    **kwargs)

Some well-known weight/bias constraints include NonNeg (to make weights non-negative), MaxNorm, MinMaxNorm, and UnitNorm, as documented at https://keras.io/api/layers/constraints/
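For reference, these constraints essentially boil down to projections applied to the weight array after each update. A rough Julia sketch of two of them (the names nonneg! and maxnorm! are just for illustration, not existing Flux API):

using LinearAlgebra

nonneg!(W) = (W .= max.(W, 0); W)        # NonNeg: clip negative entries to zero

function maxnorm!(W; maxval = 2.0)       # MaxNorm: rescale rows whose norm exceeds maxval
    for i in axes(W, 1)
        n = norm(@view W[i, :])
        n > maxval && (W[i, :] .*= maxval / n)
    end
    return W
end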

jeremiedb commented 3 years ago

This is just a personal impression, but my understanding is that the Flux philosophy is to handle such options at the loss/optimisation level. For example, https://fluxml.ai/Flux.jl/stable/models/regularisation/ shows how ad hoc penalties on parameters can be applied, and https://fluxml.ai/Flux.jl/stable/models/advanced/ shows how constraints can be applied by freezing params.

By doing so, I think it keeps the control flexible and applicable to any operator without the need to load numerous arguments onto each of them, which is a plus for the Flux experience in my opinion.
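For instance, a minimal sketch of the loss-level route along the lines of the regularisation docs (the model size and penalty weight here are arbitrary):

using Flux
using LinearAlgebra: norm

m  = Chain(Dense(10, 5, relu), Dense(5, 2))
ps = Flux.params(m)

# penalise large weights through the loss instead of constraining the layer itself
penalty() = sum(norm, ps)
loss(x, y) = Flux.Losses.logitcrossentropy(m(x), y) + 0.01f0 * penalty()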

yewalenikhil65 commented 3 years ago

Hi @jeremiedb, I think you are right. But is this ad-hoc constraint on params (as mentioned in the documentation you linked) applied throughout the training process when we use Flux.train!? I think not.

I had a little difficulty understanding this.

CarloLucibello commented 3 years ago

You can pass a callback to the train! function, or define your own custom training loop: https://fluxml.ai/Flux.jl/stable/training/training/#Custom-Training-loops-1

Typically, constraints are implemented either by projecting the weights onto the constrained space after each update, e.g.

for p in params(model)
    p .= clamp.(p, -2, 2)   # project each parameter into [-2, 2] in place
end

or by reparametrizing the weights. The former is what Keras' constraints do. From the page you linked: "They are per-variable projection functions applied to the target variable after each gradient update (when using fit())."
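For concreteness, a minimal sketch of that projection approach inside a custom training loop (model, loss, and train_data stand in for your own setup):

using Flux

opt = Descent(0.01)
ps  = Flux.params(model)

for (x, y) in train_data
    gs = gradient(() -> loss(x, y), ps)
    Flux.Optimise.update!(opt, ps, gs)
    for p in ps
        p .= clamp.(p, -2, 2)   # project back onto the constrained space after each update
    end
end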

In the latter case, you have to define your own layer. See the weight norm PR https://github.com/FluxML/Flux.jl/pull/1005 for an incomplete attempt to extend reparametrization to all layers.
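As a rough illustration of the reparametrization route (this NonNegDense is hypothetical, not part of Flux): store unconstrained parameters and map them through softplus in the forward pass, so the effective weights are always non-negative.

using Flux

struct NonNegDense{M, B, F}
    V::M       # unconstrained parameters; softplus(V) is the effective weight matrix
    b::B
    σ::F
end

NonNegDense(in::Integer, out::Integer, σ = identity) =
    NonNegDense(randn(Float32, out, in), zeros(Float32, out), σ)

Flux.@functor NonNegDense

(l::NonNegDense)(x) = l.σ.(softplus.(l.V) * x .+ l.b)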

yewalenikhil65 commented 3 years ago

@CarloLucibello Thank you for your suggestion. Did you mean the following way of passing a callback function?

using Flux
using Flux: @epochs

ps = Flux.params(model)
cb = function ()   # callback enforcing the constraint after each training step
    ps[1] .= abs.(ps[1])            # e.g. keep only absolute (non-negative) values of the first layer's weights
    # or, alternatively, clamp them into the [0.0, 1.0] interval:
    # ps[1] .= clamp.(ps[1], 0.0, 1.0)
end
cb()   # apply the constraint once to the initial weights
@epochs args.epochs Flux.train!(loss, ps, train_data, opt, cb = cb)