StephenChan opened this issue 7 months ago (status: Open)
First of all, there are the low-level deep learning frameworks (PyTorch, Tensorflow, Google JAX), and then there are libraries which define more approachable interfaces on top of those frameworks (Keras, PyTorch Lightning, fastai). My general impression is that you want one of the more approachable interfaces unless you're doing heavy customization/research, which we aren't yet. But, the framework may or may not influence the choice of library.
So far, I don't know which framework is "better" overall, but quick notes on them:
Tensorflow is the oldest, then PyTorch, then JAX which is the newest.
Tensorflow was owned by Google, but it's FOSS under Apache license as of 2015; their engineers may still have a big part in development though. PyTorch was owned by Meta, but it's FOSS under modified BSD license as of 2022, and is now part of the Linux Foundation umbrella; their governing board apparently has reps from various corporations - Amazon, Google, Microsoft, Meta etc. JAX is owned by Google.
PySpacer already uses PyTorch for EfficientNet feature extraction, and PyTorch was also used for creating coralnet 1.0's EfficientNet feature extractor weights (blog post; utility functions are still in spacer/models of the pyspacer repo). Up to this point, I haven't delved enough into those details to be personally comfortable with PyTorch usage (that was all Oscar and Qimin). However, it's certainly nice to reuse a framework that's already in the project: the torch and torchvision packages take a large amount of time and space to install in the Docker build, and presumably that would roughly double if we installed Tensorflow / JAX on top of them.
My notes on libraries so far - very preliminary stuff, partly based on word on the street at Stack Overflow / Reddit / etc., and obviously there are more libraries out there:
Keras: Part of Tensorflow's org, but released support for PyTorch and JAX very recently (version 3.0, 1-2 weeks ago!).
PyTorch Lightning: Apparently the high-level-ness is comparable to Keras; that is, seemingly balanced enough to work for most folks. First public release was 2019.
PyTorch Ignite: Part of PyTorch's GitHub org. I'm uncertain on its maturity as a library, as it's currently on version "0.4.13", but the first public release was 2018 which is not all that young.
fastai: PyTorch based, and inspired by Keras. Seems to be one of the most opinionated library choices: easy to get going and does use best practices, but might be painful to do even moderate customizations. Their quick-start guide starts with from X import * statements, which is not great practice in Python. I'm particularly unimpressed by their landing page, with its feed of clickbait-headline posts and the "Make again" header.
FWIW: I've used PyTorch Lightning and it's a good choice if we go down the PyTorch route. You could probably code this up in 4-8 hours for single-threaded CPU training. Optimizing on GPU isn't that hard really, although there are some hyperparameters that affect performance.
David
Good to know, as from what I've seen so far, I was also slightly leaning towards PyTorch Lightning. So I guess the migration from scikit-learn's pickled classifier format could really be the part that takes the most work.
+1 for Lightning. I think this might also be a decent case for setting up some kind of MLOps tooling for experimentation too.
Here's a quick way to add the gpu as an option to torch_utils.py
https://github.com/yeelauren/pyspacer/blob/train_coralnet/spacer/torch_utils.py
Cool! Thanks, good to know what torch's CPU/GPU switch would look like. Yeah, definitely some organized experiments would be worthwhile for this issue.
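For anyone following along, the usual torch idiom for that CPU/GPU switch looks something like the sketch below. This is a generic example, not the actual code in the linked torch_utils.py:

```python
# Common torch idiom for making GPU use optional.
import torch

def get_device(use_gpu: bool = True) -> torch.device:
    """Return a CUDA device if requested and available, else CPU."""
    if use_gpu and torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

# The model and each batch are then moved onto whichever device was chosen:
# model = model.to(device)
# features, labels = features.to(device), labels.to(device)
device = get_device()
```

The same code path then runs on either device, so CPU-only installs keep working.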
A colleague shared this with me, which might be an option for sklearn: https://docs.rapids.ai/api/cuml/stable/
Ah interesting - offhand I can't tell if cuML can do MLP, but might be an option if we end up being flexible on the training algorithm.
I've also just heard about skorch but I'm also not sure what it could do for us yet.
A couple of thoughts on this.
If we go to deep learning, I'd also vote for PyTorch with some wrapper, like lightning.
However, I don't think it makes much sense to move to GPU-based training for LR or 1-2-hidden-layer MLPs. I think there is a lot we can do to set convergence criteria, thread data loading while computing gradients, and so on.
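As one concrete example of the "thread data loading while computing gradients" idea, torch's DataLoader can prefetch batches in background worker processes while the main process runs the forward/backward pass. A generic sketch with synthetic stand-in data:

```python
# Sketch: background-worker data loading with torch's DataLoader.
# The dataset here is synthetic; in pyspacer it would be extracted features.
import torch
from torch.utils.data import DataLoader, TensorDataset

X = torch.randn(1000, 16)          # stand-in for feature vectors
y = torch.randint(0, 3, (1000,))   # stand-in for labels

loader = DataLoader(
    TensorDataset(X, y),
    batch_size=64,
    shuffle=True,
    num_workers=2,  # batches are prepared in background worker processes
)

for xb, yb in loader:
    pass  # the training step would consume each prefetched batch here
```

With num_workers > 0, batch preparation overlaps with gradient computation instead of alternating with it.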
So the reason to go to GPUs, IMO, would be if we move to deep fine-tuning of the base networks. There is no doubt that would lead to more accurate classifiers, in particular for larger sources.
The issue/complexity with deep fine-tuning is on the invoke (predict) step. The fine-tuned nets will be the same size (in bytes) as the original network, so there is a lot more data to load when a predict request arrives, which will affect invoke latency.
Anyways, my 2 cents.
Thanks for chiming in, Oscar - yeah, I can at least confirm there is much more we can look at for optimizing classifier training, before moving to GPU (I'm currently looking into the data loading). As we do so, we should keep benchmarking and looking for where the bottleneck is. In general I don't know which algorithms (like MLP with our params) benefit the most from GPU, so it's good to have more experienced insight on that.
I'm definitely not familiar enough with the network-tuning process to fully understand that part, but will keep your notes in mind!
I'm planning to work on a PyTorch implementation of the MLP in the near future. I believe it will be beneficial for both MERMAID and CoralNet. Additionally, I'm not satisfied with the limitations of sklearn regarding #98. Therefore, I think it would be better to work towards a PyTorch implementation for the long-term goal, or at the very least for classification.
There's very much a question of whether this is worth the leap for us in the short term, but either way it's worth documenting what this might involve.
First of all, we use scikit-learn for running training, whether it's MLP (multilayer perceptron) or LR (logistic regression) / SGD (stochastic gradient descent) training.
Second, scikit-learn is CPU-only and their MLP is apparently on the simple side, per a few sources:
- The scikit-learn FAQ (note that MLP = neural network).
- The scikit-learn neural networks page.
- A Stack Overflow Q&A, plus a follow-up comment and reply.
So, GPU is relevant for MLP training generally, and scikit-learn can't do GPU.
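For reference, the current setup amounts to scikit-learn calls of roughly this shape. This is an illustrative sketch with synthetic data; the real hyperparameters live in train_utils.py and differ from these:

```python
# Rough sketch of the scikit-learn training pattern pyspacer relies on.
# Data and hyperparameters here are illustrative placeholders.
import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.neural_network import MLPClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 16))     # stand-in for extracted features
y = rng.integers(0, 3, size=200)   # stand-in for point labels

# MLPClassifier runs on CPU only; there is no device/GPU argument to pass.
clf = CalibratedClassifierCV(MLPClassifier(hidden_layer_sizes=(10,),
                                           max_iter=50))
clf.fit(X, y)
probs = clf.predict_proba(X)  # calibrated per-class scores
```

Note the classifier that gets pickled is the CalibratedClassifierCV wrapper, which is what ties the stored format to scikit-learn.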
Next there's the question of how much work it'd take to support GPU-training in pyspacer:
Need to pick a different ML library which supports GPU training, and need to learn how to install and use that library.
The actual training calls happen in train() of train_utils.py, a function that's about 70 lines long (although somewhat dense). So the scikit-learn calls and logic in that function need to be replaced with calls/logic of a different ML library. I'm sure the inputs also need to be in a slightly different format, so there'll be some data manipulation work to be done.
classify_features needs to be reworked to accept classifiers created by the different ML library. The classifier format we currently use is a pickle of CalibratedClassifierCV, a scikit-learn class, so the current format is scikit-learn-specific. (Side note: it would be great to move away from library-specific classifier formats, regardless of whether that move is done as part of a move towards GPU or not.)
For compatibility with existing coralnet classifiers, as well as for experimentation purposes, training should support: scikit-learn CPU training, non-scikit-learn GPU training, and probably also non-scikit-learn CPU training. Classification should support scikit-learn classifiers and non-scikit-learn classifiers (ideally a common interchange format for better future-proofing).
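One hypothetical way to support both classifier kinds in classify_features is to hide each library behind a tiny common predict interface, rather than branching on pickled object types. The class names below are made up for illustration:

```python
# Sketch: a common predict_proba interface over both classifier kinds.
# SklearnClassifier / TorchClassifier are hypothetical names, not pyspacer code.
import numpy as np
import torch

class SklearnClassifier:
    def __init__(self, clf):
        self.clf = clf  # e.g. a loaded CalibratedClassifierCV pickle

    def predict_proba(self, features: np.ndarray) -> np.ndarray:
        return self.clf.predict_proba(features)

class TorchClassifier:
    def __init__(self, model):
        self.model = model  # an nn.Module whose output is raw logits

    def predict_proba(self, features: np.ndarray) -> np.ndarray:
        with torch.no_grad():  # inference only; no gradients needed
            logits = self.model(
                torch.as_tensor(features, dtype=torch.float32))
            return torch.softmax(logits, dim=1).numpy()
```

classify_features would then only see objects with a predict_proba method, which also makes a later move to a library-neutral stored format (e.g. exported weights plus metadata) easier.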
I'll make a follow-up comment on potential ML libraries for GPU training.