applied-ml-bde / bde

Bayesian Deep Ensembles
https://applied-ml-bde.github.io/bde/
BSD 3-Clause "New" or "Revised" License

Make `fit` jitable #22

Closed: cklyne closed this issue 2 months ago

cklyne commented 2 months ago

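The current `fit` function is not jitted. Because `fit` prepares the training data and initializes some parameters, it can only be partially jitted. I propose splitting it into two parts: a non-jittable outer part that validates the input, initializes the internal attributes and does the remaining preparation, and a jitted inner function (`jitted_fit`) that runs the training.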

The `fit` method would then look something like this:

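def fit(self, x, y):
    # Input validation (runs eagerly, outside jit)
    ...
    # Initialize internal attributes (per scikit-learn conventions)
    ...
    # Other non-jittable preparations, e.g. building the `train` and
    # `validation` data and the `metrics` used below
    ...
    # Jitted training on the prepared model, params, data and metrics
    self.params_ = jitted_fit(
        self.model_,
        self.params_,
        train,
        validation,
        metrics,
        **kwargs,
    )
    return self  # scikit-learn convention: fit returns the estimator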
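A minimal, self-contained sketch of the proposed split, assuming a toy linear-regression model trained with plain full-batch gradient descent in place of the project's actual model, optimizer and data pipeline. The names `apply_model`, `mse_loss`, `LinearEstimator` and this `jitted_fit` signature are hypothetical and do not reflect the bde API. Validation and parameter initialization run eagerly in `fit`; only the pure training loop is compiled with `jax.jit`, with `n_epochs` marked static because `jax.lax.scan` needs a concrete loop length.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


def apply_model(params, x):
    # Hypothetical stand-in for self.model_: a linear model
    return x @ params["w"] + params["b"]


def mse_loss(params, x, y):
    return jnp.mean((apply_model(params, x) - y) ** 2)


@partial(jax.jit, static_argnames=("n_epochs",))
def jitted_fit(params, x, y, lr, n_epochs):
    # Pure training loop: only arrays and static config enter here,
    # so the whole loop compiles into a single XLA computation.
    def epoch(params, _):
        loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        return params, loss

    params, losses = jax.lax.scan(epoch, params, None, length=n_epochs)
    return params, losses


class LinearEstimator:
    """Hypothetical sklearn-style estimator illustrating the split."""

    def __init__(self, lr=0.1, n_epochs=200, seed=0):
        self.lr = lr
        self.n_epochs = n_epochs
        self.seed = seed

    def fit(self, x, y):
        # Non-jittable part: value-dependent checks and Python control flow
        x = np.asarray(x, dtype=np.float32)
        y = np.asarray(y, dtype=np.float32)
        if x.ndim != 2 or y.shape != (x.shape[0],):
            raise ValueError("expected x of shape (n, d) and y of shape (n,)")
        # Initialize parameters; shapes depend on the data seen here
        key = jax.random.PRNGKey(self.seed)
        params = {
            "w": 0.01 * jax.random.normal(key, (x.shape[1],)),
            "b": jnp.zeros(()),
        }
        # Jitted part: the training loop itself
        self.params_, self.losses_ = jitted_fit(
            params, jnp.asarray(x), jnp.asarray(y), self.lr, self.n_epochs
        )
        return self

    def predict(self, x):
        return apply_model(self.params_, jnp.asarray(x, dtype=jnp.float32))
```

Keeping `jitted_fit` a module-level pure function, rather than jitting a method, avoids having to treat `self` as a static (hashable) argument: everything that depends on Python-side estimator state is resolved before the call.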