FluxML / MLJFlux.jl

Wrapping deep learning models from the package Flux.jl for use in the MLJ.jl toolbox
http://fluxml.ai/MLJFlux.jl/
MIT License

Add `rng` to model hyperparameters and to `builder` signature #160

Closed ablaom closed 3 years ago

ablaom commented 3 years ago

Currently there is no way for the user to pass an RNG to the builder for weight-initialisation purposes, which limits reproducibility to reseeding the global RNG.

For the same reason, one cannot parallelise multiple learning curve computations with MLJ.learning_curve, which also requires a user-specifiable RNG.

Changing the builder signature will be breaking, and I don't really see a way around this.

edit: Actually, we could support two signatures for backwards compatibility: if the new signature is not implemented, we fall back to the old one, i.e. the rng specified by the user plays no role in weight initialization.
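The fallback described above could be sketched with Julia's multiple dispatch; the method name `build` and its argument names here are hypothetical, not the actual MLJFlux API:

```julia
using Random

abstract type Builder end

# Hypothetical old signature, which a user-defined builder may implement:
#     build(b::Builder, n_in, n_out)
# Hypothetical new signature, accepting an RNG for weight initialisation:
#     build(b::Builder, rng::AbstractRNG, n_in, n_out)

# Generic fallback: if a builder only implements the old two-argument
# signature, the new one dispatches to it and the rng is silently ignored.
build(b::Builder, rng::AbstractRNG, n_in, n_out) = build(b, n_in, n_out)
```

With this in place, existing builders keep working unchanged, while new builders can overload the RNG-accepting method to get reproducible weight initialisation.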

cc @bkamins

ablaom commented 3 years ago

@bkamins see my edit above