Sm00thix / IKPLS

Fast CPU and GPU Python implementations of Improved Kernel PLS by Dayal and MacGregor (1997) and Shortcutting Cross-Validation by Engstrøm (2024).
https://ikpls.readthedocs.io/en/latest/
Apache License 2.0

reproducibility and estimated measurements #24

Closed parmentelat closed 2 months ago

parmentelat commented 3 months ago

a final word about the paper's reproducibility

some values in the figure appear with a circle; about that the legend states that:

A square means that the experiment was run until the time per iteration had stabilized and used to forecast the time usage if the experiment was run to completion

can you please comment further on the practical means, if any, to achieve that, and on possible means to automate that process nevertheless?

Sm00thix commented 3 months ago

Hi @parmentelat,

Sure! Some of the experiments with leave-one-out cross-validation (LOOCV) would take extremely long (up to several decades in one instance) to run to completion - especially the ones using the scikit-learn implementation.

For this reason, I opted to let the experiments run for a while and then, after some time, use the ratio of (time passed / cross validation splits completed) to forecast how much time would pass if I had let the entire experiment run to completion. This is a sensible thing to do as the splits are balanced. Thus, each split is expected to take the same amount of time to complete.

I waited until some multiple of n_jobs cross-validation splits had completed, noted the time, and manually forecasted the expected time to finish. The timings used in this process were from sklearn.model_selection.cross_validate internal timers that are printed when setting verbose=1.
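The forecast described above amounts to a linear extrapolation. A minimal sketch, assuming balanced splits; the function name `forecast_total_seconds` and the numbers are hypothetical and not taken from the benchmark script:

```python
def forecast_total_seconds(elapsed_seconds, splits_done, splits_total):
    """Linearly extrapolate total runtime from the splits completed so far."""
    if splits_done <= 0:
        raise ValueError("at least one split must have completed")
    # Balanced splits are assumed to take the same time each.
    seconds_per_split = elapsed_seconds / splits_done
    return seconds_per_split * splits_total


# Hypothetical numbers: 64 of 100000 LOOCV splits finished after 3200 seconds.
estimate = forecast_total_seconds(3200.0, 64, 100_000)
print(f"forecast: {estimate / 86400:.1f} days")
```

Choosing `splits_done` as a multiple of n_jobs keeps every worker equally loaded during the timed window, which is what makes the per-split ratio meaningful.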

Automation of this approach may be feasible for the regular CPU implementations by emptying some of cv=KFold() (currently defined on line 288 in timings.py).
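One way such "emptying" could look, sketched in plain Python rather than taken from timings.py: keep only the first few splits of the splitter while leaving each training set at full size. The helpers `leave_one_out` and `first_splits` are hypothetical names. scikit-learn's cross_validate accepts an iterable of (train, test) index pairs as its cv argument, so a truncated list like this could in principle replace the KFold object.

```python
from itertools import islice


def leave_one_out(n_samples):
    """Yield (train_indices, test_indices) pairs, one per held-out sample."""
    for i in range(n_samples):
        train = [j for j in range(n_samples) if j != i]
        yield train, [i]


def first_splits(splits, n_jobs, rounds=1):
    """Keep only the first rounds * n_jobs splits of an iterable of splits."""
    return list(islice(splits, rounds * n_jobs))


# Hypothetical setup: 1000 samples, 8 workers, two full rounds of work.
subset = first_splits(leave_one_out(1000), n_jobs=8, rounds=2)
print(len(subset), "splits kept; training size", len(subset[0][0]))
```

Because every kept split still trains on all other samples, the time per split stays representative of the full experiment.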

Automation using the fast cross-validation implementation is likely not necessary as it is extremely fast - even for LOOCV with 1e6 samples. However, it can likely be supported by simply removing some multiple of n_jobs indices from cv_splits after line 334 and before line 335 in timings.py.

Automation using the JAX implementations can probably be achieved in the same manner as for the fast cross-validation implementation by removing some of the indices in cv_splits after line 439 and before line 440 in timings.py.

These are my three ideas for automation of the estimation process. However, I have not actually tried it and it may be the case that I have missed some detail that makes automation more challenging.

Please let me know if you will accept the manual approach. Otherwise, I will try and implement my ideas as stated above.

parmentelat commented 3 months ago

hi @Sm00thix I'm also pinging @basileMarchand as he expressed interest in the matter

I must admit it would be very nice to be able to script the process of re-obtaining all the numbers depicted in the paper's main figure; in this respect I believe your ideas above truly deserve a shot at least, since you seem to deem it feasible; and even if full automation turns out to be out of reach, any progress that would make it more accessible for others to retrieve your figures for estimated runs would be welcome

regardless, once you are done with the improvements if any, I would recommend that you enrich paper/README.md to either give instructions on how to reproduce (if you improve the current situation), or at the very least how you have proceeded - i.e. the details above - if not

Sm00thix commented 2 months ago

Hi @basileMarchand and @parmentelat,

To accommodate your request to automate the benchmark estimation process, I have attempted to implement the ideas I described in my previous comment in this thread.

I was successful for the regular CPU implementations. These are scikit-learn's NIPALS and my own NumPy-based IKPLS implementations - i.e., the ones that can be benchmarked with time_pls.py with the flags -model sk, -model np1, and -model np2, respectively. They can now be estimated by adding the --estimate flag to the call to time_pls.py in the benchmark branch of the IKPLS repo.

For the JAX implementations and fast cross-validation implementation, I was unable to implement automation of the estimation process. Instead I updated the paper/README.md to clarify the manual approach as suggested by @parmentelat.

These changes are visible in https://github.com/Sm00thix/IKPLS/commit/841dc4aa59e6ac38a5971a4e493f595e2d387f32.

If you agree with these changes, I will merge them into main. Alternatively, if you think it is confusing to have automation only available for a subset of the implementations, I will remove the possibility to automate the regular CPU implementations and instead only clarify the manual approach in paper/README.md.

Please let me know which option you prefer. I will then proceed as you wish and close the issue.

Below are my explanations for why I was unable to automate the benchmark estimation of the JAX implementations and the fast cross-validation implementation:

I was not successful for the JAX implementations. The reason is the way I implemented the cross-validation: simply removing elements from cv_splits removes them from both training and validation splits, effectively just making n smaller. I do not think I can make this work. At the very least, I would have to rewrite cross_validate (and possibly related methods) in jax_ikpls_base.py. I sincerely do not think I should modify the implementations with the sole goal of making automation of benchmark estimation possible.

For the fast cross-validation algorithm, timing the execution of a small number of cross-validation splits, computing the time/split ratio, and then linearly forecasting a time estimate for computing the whole cross-validation fails. The reason lies in the nature of the fast cross-validation algorithm: it initially performs one relatively expensive computation and then performs relatively cheap operations while iterating over the folds. A time/split ratio measured over a few splits attributes that one-time cost to every split, so the linear forecast overestimates the total runtime. Details about the algorithm can be found in this paper. In practice, estimation of the fast cross-validation algorithm seems somewhat unnecessary as it is, as the name implies, fast. Even for the somewhat extreme case of leave-one-out cross-validation with a million samples, 500 features, 10 targets and 30 PLS components, it only takes 20 minutes to complete. I also did not estimate its runtime in my experiments - as evidenced by the plots in paper/timings/timings.png.
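The effect of the one-time cost on a linear forecast can be illustrated with a toy cost model. This is a sketch with made-up costs (a 600 s setup and 0.5 ms per split), not measurements of the fast cross-validation algorithm; both function names are hypothetical.

```python
def naive_forecast(setup_seconds, per_split_seconds, splits_timed, splits_total):
    """Forecast that divides all elapsed time, setup included, by splits timed."""
    elapsed = setup_seconds + per_split_seconds * splits_timed
    return elapsed / splits_timed * splits_total


def true_total(setup_seconds, per_split_seconds, splits_total):
    """Actual runtime when the setup cost is paid exactly once."""
    return setup_seconds + per_split_seconds * splits_total


# Hypothetical costs: time 100 splits, then extrapolate to 1e6 splits.
naive = naive_forecast(600.0, 0.0005, 100, 1_000_000)
actual = true_total(600.0, 0.0005, 1_000_000)
print(f"naive forecast: {naive:.0f} s, actual: {actual:.0f} s")
```

Under these assumed costs the naive forecast is off by more than three orders of magnitude, because the setup time is charged once per split instead of once overall.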

parmentelat commented 2 months ago

hi @Sm00thix

thanks for your work in this area, I really believe this kind of apparently unrewarding task contributes to a much better material for others to toy with :)

My advice would be to

thanks again for bearing with me on this one :)

Sm00thix commented 2 months ago

Hi @parmentelat,

Thanks for your assistance on this one. I agree that we have made it easier for others to toy around!

I have merged your pull request to the benchmark branch and, in turn, merged the benchmark branch with main. I have also tried to improve the very last paragraph of paper/README.md as per your instructions.

All the changes are merged to main in https://github.com/Sm00thix/IKPLS/commit/ee3de3e51c254693b4c53e12670b4383795c60c5

If you agree with these changes, I think we can close this issue :-)

parmentelat commented 2 months ago

hey @Sm00thix

actually I am currently trying to recompute all the data from the figure; this is something that I'll put on the back burner, hopefully just waiting for the job to complete; for now I have secured a GPU-capable container to that end

in the process I have run into 2 separate issues

I'll get this repro script to work as far as I can, and will file a PR once I'm done; please keep this open until then

parmentelat commented 2 months ago

on a side track, I am seeing this during my reproducibility attempts, with jax-related runs; feel free to open a separate issue if need be


```
# 13/282: (jax2 x 30 x 1 x 100000 - False) - expect 4.27
python3 time_pls.py -o timings/user_timings.csv -model jax2 -n_components 30 -n_splits 1 -n 100000 -k500 -m 1 -n_jobs -1
Fitting JAX Improved Kernel PLS Algorithm #2 with 30 components on 100000 samples with 500 features and 1 targets.
/home/ubuntu/miniconda3/lib/python3.12/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
  warnings.warn(
/home/ubuntu/miniconda3/lib/python3.12/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
  warnings.warn(
Time: 4.175951265729964
```

Sm00thix commented 2 months ago

Hi @parmentelat,

In regards to your benchmarks: Sounds good. Depending on your machine, running all the benchmarks may take a few weeks if I recall correctly. That is, if you estimate the same ones I did and run to completion the same ones I did.

I'm sorry if I broke your notebook with some of my own changes. Be aware that a NaN in the 'inferred' column in `paper/timings/timings.csv` is to be regarded as False. That is, I (admittedly somewhat foolishly) did not write anything in that column if I did not estimate the value. Doing boolean logic with NaNs is error-prone and I tried to account for that when I modified your script. Please let me know if I can be of any help in this regard. I will keep this issue open until that is resolved.
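To illustrate the pitfall: in Python, `bool(float("nan"))` evaluates to True, so a plain truthiness check would treat a missing value as an estimated one. A minimal sketch of a safer interpretation, where `is_inferred` is a hypothetical helper and not the code used in the notebook:

```python
import math


def is_inferred(value):
    """Interpret a cell of the 'inferred' column; missing values mean False."""
    if value is None:
        return False
    if isinstance(value, float) and math.isnan(value):
        return False  # NaN marks a run that was not estimated
    if isinstance(value, str):
        return value.strip().lower() == "true"  # empty string maps to False
    return bool(value)


# Hypothetical cell values as they might appear after loading the CSV.
cells = [True, float("nan"), "", "True", None]
flags = [is_inferred(c) for c in cells]
print(flags)
```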

In relation to the FutureWarning that you mention in your comment, I found the culprit and fixed it. I commented on the details in https://github.com/Sm00thix/IKPLS/commit/9cf6cc7eaa4b840676a5cc366bb10feee6b128d8. I also ran a few of the JAX benchmarks and noticed no difference when compared to my original benchmarks.

I initially forgot to bump the ikpls version number. I did this in https://github.com/Sm00thix/IKPLS/commit/fc694c035c12ac7c8cfd6f63e08a172a4c5fe1a7.

All of these changes are currently in the dev branch which I will merge to main once all tests pass.

Alright. I had made a couple of errors which caused the computation of gradients using reverse mode differentiation to fail. I fixed those in https://github.com/Sm00thix/IKPLS/commit/97c902472fec7ede5dfe9b9bd128fc965e730d2a.

Update

I merged the dev branch to main.

parmentelat commented 2 months ago

the paper, at least timings.csv, holds 606 measurement points; right now I have collected in the 360s; I might not go far enough for it to spend weeks, but I'll try to maximize the number of data points that I can gather; I will file a PR with the latest changes that I made in the notebook now that this script works fine for me

no worries about the notebook, the first version was very rough and needed ironing anyways; plus your code was actually working, so...

also for clarity, none of the warnings that I reported are to be deemed showstoppers, it's just FYI in case you'd have missed them

Sm00thix commented 2 months ago

Hi @parmentelat,

I merged your PR in https://github.com/Sm00thix/IKPLS/commit/b27201b96477e66fab2f1932141374dcdc3e7719. Thanks for your work on this one! Do you want to wait until your benchmarks have completed before we close this issue? Otherwise, I suggest we close it now :-)

parmentelat commented 2 months ago

ok for closing now