josephrocca opened this issue 2 years ago
Thanks for the feedback!
I've heard of JAX, and I've used TensorFlow.js and TensorFlow in Python. They're fine and offer lots of tools, but as far as I understand, the types of models that can be loaded AND trained in TensorFlow.js are limited. Looking at their API docs now, there may be more support than before: I investigated this a lot about a year ago (July 2021), so my information may be outdated. If you live in a TensorFlow world, then maybe it's fine, but I'm living in a PyTorch world.
Context: I'm passionate about Federated Learning, and I think the best way to bring models to many users is through browsers and webpages, so enabling training in JavaScript would really help people share models. I've seen many other small projects for training in the browser, and TensorFlow.js is obviously the most mature, but most of my colleagues live in a PyTorch world and we don't want to rewrite our models for some other framework. I've looked a lot for ways to convert PyTorch models so they can be trained in TensorFlow.js, but there were a few gotchas and it seems like it's not possible in general. Discussed here. I want to work with tools provided by OpenMined, such as PySyft, but those tools no longer have a way to train in JavaScript; they mainly support Python now. I built this as an example to motivate training in JavaScript, so that other projects can support it or so that ONNX Runtime (ORT) Web is motivated to improve support for it. Until then, I might copy this example for a toy project, but I don't have any specific plans yet.
I didn't do much serious profiling before, but I just did some:
- Batch size: 64
- Machine: Windows 11, Intel(R) Core(TM) i7-1065G7 CPU @ 1.30GHz 1.50 GHz
- Average batch time for this demo in my browser: about 500 ms
- Average batch time for PyTorch training on CPU in WSL2: about 10 ms

So the browser was about 50X slower.
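The comment doesn't say exactly how these averages were taken. As a minimal sketch of one way to get comparable numbers, the TypeScript below times a hypothetical `trainStep` callback (standing in for one ORT Web training batch) with `performance.now()`, skips a couple of warm-up batches (an assumption, so one-off costs like wasm compilation don't skew the mean), and computes the slowdown ratio, here 500 ms / 10 ms = 50.

```typescript
// Time `numBatches` calls of a (hypothetical) training step and return the mean time per batch in ms.
async function averageBatchTimeMs(
  trainStep: () => Promise<void> | void,
  numBatches: number,
  warmupBatches = 2,
): Promise<number> {
  // Warm-up runs are excluded so one-off costs (wasm compilation, first allocations) don't skew the mean.
  for (let i = 0; i < warmupBatches; i++) {
    await trainStep();
  }
  const start = performance.now();
  for (let i = 0; i < numBatches; i++) {
    await trainStep();
  }
  return (performance.now() - start) / numBatches;
}

// How many times slower the browser run is than the native run.
function slowdownFactor(browserMsPerBatch: number, nativeMsPerBatch: number): number {
  return browserMsPerBatch / nativeMsPerBatch;
}

console.log(slowdownFactor(500, 10)); // 50
```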
Thanks for the tips on how to speed this up. Should I just copy that script into my repo? BTW, it seems to be about the same speed whether hosted locally or on GitHub. I've been wondering how I could make it faster, and I'm not sure I'm using ORT Web properly. Also, the browser UI becomes unusable during training. Any tips to avoid that?
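One common way to keep the page responsive, sketched here under the assumption that training runs on the main thread as a loop of `await`ed batches, is to hand control back to the event loop between batches so the browser can repaint and process input. `trainStep`, `onProgress`, and `trainWithoutFreezingUi` are hypothetical names, not part of ORT Web.

```typescript
// Resolve on a later macrotask, giving the browser a chance to repaint and handle input.
function yieldToEventLoop(): Promise<void> {
  return new Promise((resolve) => setTimeout(resolve, 0));
}

// Hypothetical training loop: `trainStep` stands in for one ORT Web training batch.
async function trainWithoutFreezingUi(
  trainStep: (batchIndex: number) => Promise<void> | void,
  numBatches: number,
  onProgress?: (batchesDone: number) => void,
): Promise<void> {
  for (let i = 0; i < numBatches; i++) {
    await trainStep(i);
    onProgress?.(i + 1);
    // Yield between batches so the page stays interactive during training.
    await yieldToEventLoop();
  }
}
```

This only helps between batches: at roughly 500 ms per batch, the UI will still stall for that long each time. Moving the whole training loop into a Web Worker avoids that, at the cost of passing data and progress updates via `postMessage`.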
I only heard of pyscript once I finally got this working. It would pretty much make this obsolete and make training much easier. I hope they get PyTorch working in pyscript soon 🤞🏻, but from what I understand, even if it "worked", it might not be as fast as ORT Web until it can use WASM/GPUs.
Let's keep chatting here as I think this is an important space for people to follow.
Great to hear people like yourself are taking browser-based ML training seriously! Like you said, this is the only way to motivate frameworks (like ORT Web) and browsers to improve support, and to open the doors to other developers when they can see what's possible with demos like this. Someone has to take the first step.
I've created #9 to prevent the UI from becoming unusable and to add a Service Worker that proxies all requests and adds the COEP/COOP headers. Those headers trigger cross-origin isolation, which is required to use wasm threads. Note that the service worker script will cause the page to refresh when it's first loaded, but this is barely noticeable and is fine for a demo like this. The script does nothing if the page was already served with COEP/COOP headers.
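For readers unfamiliar with the trick, here is a minimal sketch of what such a service worker can do; it is not the code from #9. It re-fetches each request and returns a copy of the response with `Cross-Origin-Embedder-Policy: require-corp` and `Cross-Origin-Opener-Policy: same-origin` set. `withIsolationHeaders` and `ISOLATION_HEADERS` are hypothetical names.

```typescript
// Headers that make a document cross-origin isolated, which browsers require before
// exposing SharedArrayBuffer (and therefore multi-threaded wasm).
const ISOLATION_HEADERS: Record<string, string> = {
  "Cross-Origin-Embedder-Policy": "require-corp",
  "Cross-Origin-Opener-Policy": "same-origin",
};

// Return a copy of `response` with the isolation headers added.
function withIsolationHeaders(response: Response): Response {
  // Opaque and error responses have status 0, which the Response constructor rejects,
  // so pass them through untouched.
  if (response.status === 0) {
    return response;
  }
  const headers = new Headers(response.headers);
  for (const [name, value] of Object.entries(ISOLATION_HEADERS)) {
    headers.set(name, value);
  }
  return new Response(response.body, {
    status: response.status,
    statusText: response.statusText,
    headers,
  });
}

// Only wire up the event listeners when this file is actually running as a service worker.
const scope = globalThis as any;
if (
  typeof scope.ServiceWorkerGlobalScope !== "undefined" &&
  scope instanceof scope.ServiceWorkerGlobalScope
) {
  // Activate immediately and take control of already-open pages.
  scope.addEventListener("install", () => scope.skipWaiting());
  scope.addEventListener("activate", (event: any) => event.waitUntil(scope.clients.claim()));

  scope.addEventListener("fetch", (event: any) => {
    const request = event.request;
    // fetch() rejects this combination, so let the browser handle the request normally.
    if (request.cache === "only-if-cached" && request.mode !== "same-origin") {
      return;
    }
    event.respondWith(fetch(request).then(withIsolationHeaders));
  });
}
```

Because COEP `require-corp` also applies to cross-origin subresources, any CDN-hosted files (for example ORT Web's `.wasm` binaries) must be fetched with CORS, carry a CORP header, or be hosted on the same origin.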
> I hope they get PyTorch working in pyscript soon
Me too! PyScript is actually more of a helper library/framework that sits on top of Pyodide, which is the actual browser-based Python runtime (it's a wasm port of CPython). Here's the list of currently-ported Python packages in case you're curious. That list doesn't include pure-Python packages, since they already work in Pyodide; it's only the ones with C++/Rust/etc. code that need to be "ported" to wasm with Emscripten. The heavy-compute parts of PyTorch are obviously written in languages other than Python, so they'll be running as wasm, but you're right that it still might be slower than ORT Web simply because ORT Web has Microsoft behind it, has fewer moving parts, and has a narrower scope, so it's likely easier to optimise. Still super exciting either way. Regarding GPU usage in Pyodide: currently, IIUC, Emscripten has some OpenGL-to-WebGL conversion support, but I doubt that'll be useful for porting PyTorch. Hoping that WebGPU support in the major browsers (which isn't far away) will make it easier to port native GPU code to the web.
Curious to see how the performance changed with wasm threads enabled. Also wondering if you tried the WebGL backend of ORT Web? Or did it not have the required op support? I wouldn't be surprised; the WebGL backend is almost always missing op support for the models that I try to run. Hoping that with WebGPU the ORT team will be able to directly port their native GPU kernels like they do with the wasm-ported CPU ones, which would mean good GPU op coverage "for free" (well, for the cost of setting up the conversion pipeline).
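A small sketch of how the two knobs discussed here (wasm thread count and backend choice) could be picked at startup. Assumptions: `chooseOrtSettings` and `OrtSettings` are hypothetical, and the cap of 4 threads is an arbitrary choice to leave cores for the main thread, not an ORT Web default. The one hard rule it encodes is that wasm threads need `SharedArrayBuffer`, which browsers only expose on cross-origin-isolated pages.

```typescript
// Hypothetical settings object; the field names mirror the ORT Web options they would feed.
interface OrtSettings {
  executionProviders: string[];
  numThreads: number;
}

function chooseOrtSettings(
  crossOriginIsolated: boolean,
  hardwareConcurrency: number,
  preferWebGL: boolean,
): OrtSettings {
  // Multi-threaded wasm needs SharedArrayBuffer, which is only available on
  // cross-origin-isolated pages; otherwise wasm runs single-threaded.
  const numThreads = crossOriginIsolated
    ? Math.max(1, Math.min(hardwareConcurrency, 4)) // cap of 4 is an arbitrary assumption
    : 1;
  // Backends listed in order of preference, keeping wasm as the last entry
  // because it has the broadest operator coverage.
  const executionProviders = preferWebGL ? ["webgl", "wasm"] : ["wasm"];
  return { executionProviders, numThreads };
}
```

In the page, the result would be applied through onnxruntime-web's `ort.env.wasm.numThreads` (set before creating a session) and the `executionProviders` session option, e.g. `chooseOrtSettings(self.crossOriginIsolated, navigator.hardwareConcurrency ?? 1, false)`.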
Nice work on this! I've been thinking of doing something similar with JAX with either tfjs-tflite or ORT Web.
I'm curious if you have any plans to take this further, or just wanted to get a demo/POC working for fun?
Also wondering whether you've tried benchmarking against PyTorch's CPU backend? Curious how much slower it is. You could probably get a decent speedup in your GitHub Pages demo by including this script in the `<head>`. It's quite hacky, but AFAIK it's the only way to get cross-origin isolation working on GitHub Pages, since there's no feature to set COOP/COEP headers yet.

Regarding JAX, it seems like it should be relatively easy at this point (given recent improvements in model conversion tooling) to do something similar to what you've done here, but I've been thinking it'd be cool to get the whole JAX lib (CPU backend) working in the browser, so that even the model-building side of things can be done in the browser. I don't have much experience with C/C++, so the build process is kind of a black box to me, and I'm currently stuck here (more info), in case this is interesting to you. JAX is quite "lean" in terms of dependencies, so it seems like it shouldn't be toooo hard. Will try to keep chipping away at this in my spare time.
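The page-side half of that `<head>` trick can be sketched as follows, assuming a COOP/COEP-adding service worker is hosted at a hypothetical `swUrl` whose scope covers the page. If the page isn't cross-origin isolated yet, it registers the worker, waits for `navigator.serviceWorker.ready`, and reloads once so the next load passes through the worker. The `sessionStorage` guard (another assumption) stops a reload loop on browsers where isolation still fails.

```typescript
type IsolationStatus = "isolated" | "reloading" | "unsupported";

// Hypothetical page-side helper; `swUrl` must point at a worker whose scope covers this page.
async function ensureCrossOriginIsolation(swUrl: string): Promise<IsolationStatus> {
  const g = globalThis as any;
  const FLAG = "coi-reload-attempted"; // hypothetical sessionStorage key
  if (g.crossOriginIsolated) {
    // Already isolated (via real headers or the worker): clear the reload guard.
    g.sessionStorage?.removeItem(FLAG);
    return "isolated";
  }
  const serviceWorker = g.navigator?.serviceWorker;
  if (!serviceWorker) {
    return "unsupported";
  }
  // Reload at most once per tab so a browser that still refuses isolation can't loop forever.
  if (g.sessionStorage?.getItem(FLAG)) {
    return "unsupported";
  }
  await serviceWorker.register(swUrl);
  // Resolves once an active worker covers this page, so the next load goes through it.
  await serviceWorker.ready;
  g.sessionStorage?.setItem(FLAG, "1");
  g.location.reload();
  return "reloading";
}
```

It would be called early, before ORT Web is loaded, e.g. `ensureCrossOriginIsolation("./sw.js")`.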
There is apparently some work going on to get PyTorch working in the browser, which would be quite a feat - seems like there are a lot more dependencies/"moving parts" in PyTorch compared to JAX.