sbrunk / storch

GPU accelerated deep learning and numeric computing for Scala 3.
https://storch.dev
Apache License 2.0

Add popular transformer implementations #44

Open sbrunk opened 1 year ago

sbrunk commented 1 year ago

It would be great to add a few transformer architectures. This will also help us to prioritize which op and module implementations to add next.

It probably makes sense to port models from Hugging Face, so we can load/convert weights from their hub.

It might also make sense to create an extra sub-project for transformer models, where we can also put shared modules/helpers.

hmf commented 1 year ago

Would a "pico" example from Karpathy's "Let's build GPT: from scratch, in code, spelled out" be of any interest?

Later we could also "port" his nanoGPT implementation. Is this also of interest?

sbrunk commented 1 year ago

> Would a "pico" example from Karpathy's "Let's build GPT: from scratch, in code, spelled out" be of any interest?
>
> Later we could also "port" his nanoGPT implementation. Is this also of interest?

@hmf yes, absolutely (both)! Great idea! If you're interested in giving it a try, please create an issue (or directly a PR). If anything is missing module- or op-wise, we can add it along the way.

emartinezs44 commented 1 year ago

It may be a very, very large task. The Hugging Face library has become huge. I'm pretty sure that the most used models in production are BERT-related (BERT, RoBERTa) and ViT, so maybe those could be the starting point.

Pain points that I see:

  1. Tokenizers: some are written in Python and others are Rust-based. Because tokenizers are closely tied to the models, there must be some solution for that (via sbt tasks?).
  2. Pipelines: are you planning to create a similar pipelining strategy from scratch? I managed to create an implementation with BigDL DLlib that uses tokenizers and fine-tuned models as Spark pipelines. But it requires Spark and only runs in multi-CPU mode. Besides, there are several drawbacks due to the underlying framework that should be fixed; for example, it consumes a lot of heap due to the JVM's memory handling.

I assume that Scala can provide a more type-safe way of building graphs (I don't see the deep learning community being very worried about type safety), but before diving deep into this there must be some strategy to attract the PyTorch community, and personally I see a lot of problems here due to the poor numerical and visualization ecosystem in Scala, among other things. Other projects like Swift for TensorFlow failed even though they were sponsored by Google, so I think it is necessary to think about the added value (I'm sure there is some) of starting such a big task. Good job, by the way!

sbrunk commented 1 year ago

Thanks for your inputs @emartinezs44.

> It may be a very, very large task. The Hugging Face library has become huge. I'm pretty sure that the most used models in production are BERT-related (BERT, RoBERTa) and ViT, so maybe those could be the starting point.

It will take some effort, but I'm convinced it's totally doable to implement the most popular architectures. Hugging Face doesn't share model components across architectures, but I think it makes more sense for us to share common components, perhaps similar to what the explosion.ai folks are doing with curated-transformers.

> Pain points that I see:
>
>   1. Tokenizers: some are written in Python and others are Rust-based. Because tokenizers are closely tied to the models, there must be some solution for that (via sbt tasks?).

I've built a working solution in the form of a Scala API for the Rust-based fast tokenizers from Hugging Face. I think that should cover tokenization for many transformer models.

>   2. Pipelines: are you planning to create a similar pipelining strategy from scratch? I managed to create an implementation with BigDL DLlib that uses tokenizers and fine-tuned models as Spark pipelines. But it requires Spark and only runs in multi-CPU mode. Besides, there are several drawbacks due to the underlying framework that should be fixed; for example, it consumes a lot of heap due to the JVM's memory handling.

That's interesting, thanks for sharing. So far I have only focused on the core modeling and training part, using a single machine and a single GPU. I'm not sure yet what should be part of the library itself for scaling things up. Perhaps it's better (or more flexible) to have another solution that uses Storch as part of a pipeline and/or in a distributed setting. Given your experience, I'd be interested to hear your thoughts.

> I assume that Scala can provide a more type-safe way of building graphs (I don't see the deep learning community being very worried about type safety), but before diving deep into this there must be some strategy to attract the PyTorch community, and personally I see a lot of problems here due to the poor numerical and visualization ecosystem in Scala, among other things. Other projects like Swift for TensorFlow failed even though they were sponsored by Google, so I think it is necessary to think about the added value (I'm sure there is some) of starting such a big task. Good job, by the way!

Thanks. I'm trying to take a very pragmatic approach here: build something I think is fun and useful, show it to people within the Scala community first to see if there's interest (hence the Scala Days talk), listen to feedback, and then continue in the most promising direction. And of course, it also depends a lot on what people contribute and how much. So let's see how it plays out. :)

I totally agree with you that we should improve the visualization situation. If we manage to have one well-working, maintained Scala viz library with notebook support, that would already be a big step forward.

hmf commented 1 year ago

@sbrunk

> I totally agree with you that we should improve the visualization situation. If we manage to have one well-working, maintained Scala viz library with notebook support, that would already be a big step forward.

Just a heads-up: although I am not a fan of JS plotting libraries (Scala should have its own; some Java libraries exist), notebook support with visualization can be found here:

  1. Almond Scala kernel for Jupyter
  2. The (Scala) polyglot notebook

There are other Scala plotting libraries, but they are lacking or dead.