GibbsConsulting / jupyter-plotly-dash

Jupyter notebook wrapper for plotly dash applications
GNU Affero General Public License v3.0

Performance drop when using Scikit-learn #58

Open VNDRN opened 4 years ago

VNDRN commented 4 years ago

When training Scikit-learn models, I noticed that the code finishes significantly faster in a normal cell. Running the same function in a JupyterDash app takes up to 20% longer.

I tested this with the random forest regression and support vector regression models, using multiprocessing. Can JupyterDash not use multiprocessing? Or does a regular code cell in a notebook get more resources than code run within a JupyterDash app?
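One way to check the resources question is to run the same snippet in a plain cell and inside the callback, then compare the results. This is a minimal sketch using only the standard library; `available_cpus` is a hypothetical helper, not part of JupyterDash. Note that `multiprocessing.cpu_count()` (used in the code further down) reports the CPUs in the machine, not necessarily the CPUs the current process is allowed to run on.

```python
import os


def available_cpus():
    """Return (logical CPUs in the machine, CPUs this process may run on)."""
    total = os.cpu_count() or 1  # os.cpu_count() can return None
    if hasattr(os, "sched_getaffinity"):  # only on some platforms, e.g. Linux
        usable = len(os.sched_getaffinity(0))  # 0 means the current process
    else:
        usable = total  # no affinity API: assume every CPU is usable
    return total, usable


total, usable = available_cpus()
print(f"logical CPUs: {total}, usable by this process: {usable}")
```

If the two numbers differ between the cell and the callback, the workers started by `Parallel` would be competing for fewer cores in one of the two settings.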

GibbsConsulting commented 4 years ago

Are you running the code inside a callback?

There is nothing special about JupyterDash, although it does add a listener on a port for the Dash callbacks; this runs on the (default) asyncio event loop. I don't know whether Scikit-learn changes or modifies this at all.
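A quick way to see whether the callback runs in a different context from a plain cell is to record the process, the thread and whether an asyncio event loop is running, once from a cell and once from inside `forestTrainer` when it is triggered by the callback. This is a hypothetical diagnostic sketch, not part of JupyterDash or Dash; it only reports the context and makes no assumption about what the two runs will show.

```python
import asyncio
import os
import threading


def execution_context():
    """Describe the process/thread/event loop the current call runs in."""
    try:
        asyncio.get_running_loop()
        loop_running = True
    except RuntimeError:  # raised when no loop is running in this thread
        loop_running = False
    current = threading.current_thread()
    return {
        "pid": os.getpid(),
        "thread": current.name,
        "main_thread": current is threading.main_thread(),
        "event_loop_running": loop_running,
    }


print(execution_context())
```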

Are you in a position to share an example?

VNDRN commented 4 years ago

The function that trains the model is declared outside the app, but it is called from within a callback. I will try to provide an example, but I have to redact some details because of an NDA with my workplace.

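```python
from datetime import datetime
import multiprocessing

from joblib import Parallel, delayed
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error

# `data`, `train_data` and `test_data` are redacted (NDA) and defined elsewhere;
# train_data/test_data are assumed here to be (X, y) pairs.
def forestTrainer(amount):
    t1 = datetime.now()
    model = RandomForestRegressor(n_estimators=amount)

    def trainPeriod(i):
        # `i` selects the period in the original (redacted) code
        X_train, y_train = train_data
        X_test, y_test = test_data
        model.fit(X_train, y_train)
        test_predict = model.predict(X_test)
        return mean_absolute_error(y_test, test_predict)

    num_cores = multiprocessing.cpu_count()
    # Each call runs in a separate joblib worker process (default loky backend)
    scores = Parallel(n_jobs=num_cores)(
        delayed(trainPeriod)(i) for i in range(1, len(data))
    )
    t2 = datetime.now() - t1
    mean_mae = round(float(sum(scores)) / len(scores), 2)
    return "The average MAE over all experiments is {}, time is {}".format(mean_mae, t2)
```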

When used in JupyterDash, I call the function like this:

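```python
from dash.exceptions import PreventUpdate

# Registered with @app.callback(...); the Output/Input ids were not shown.
def update_output(n_clicks):
    if n_clicks is None:
        raise PreventUpdate  # button not clicked yet: leave the output unchanged
    return forestTrainer(amount)  # `amount` is defined elsewhere in the notebook
```
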
GibbsConsulting commented 4 years ago

The multiprocessing module uses separate processes to do the calculation. I don't know to what extent this is affected by how you call it.

As a test, are you also able to measure the relative calculation times if you don't use joblib's `Parallel` feature?
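A minimal timing harness for that comparison could look like the sketch below; `best_wall_time` is a hypothetical helper, not part of JupyterDash or joblib. You would call it on `forestTrainer` and on a serial variant (for example a copy that passes `n_jobs=1` to `Parallel`), once from a cell and once from the callback. It uses `time.perf_counter()` rather than `datetime.now()` because `perf_counter` is monotonic and unaffected by system clock adjustments. It measures wall-clock time rather than `time.process_time()`, because `process_time` only counts CPU time of the calling process and would miss the work done in joblib's worker processes.

```python
import time


def best_wall_time(fn, *args, repeats=3, **kwargs):
    """Run fn `repeats` times; return (best wall-clock seconds, last result)."""
    best = float("inf")
    result = None
    for _ in range(repeats):
        start = time.perf_counter()  # monotonic, high-resolution clock
        result = fn(*args, **kwargs)
        best = min(best, time.perf_counter() - start)
    return best, result


def busy(n):
    """Stand-in CPU-bound workload for trying out the harness."""
    return sum(i * i for i in range(n))


elapsed, value = best_wall_time(busy, 100_000)
print(f"best of 3: {elapsed:.4f} s")
```

Taking the best of several runs reduces noise from one-off effects such as worker processes being started on the first call.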