facebookresearch / shumai

Fast Differentiable Tensor Library in JavaScript and TypeScript with Bun + Flashlight
https://facebookresearch.github.io/shumai
MIT License

Examples about training and inference #128

Open fabiospampinato opened 1 year ago

fabiospampinato commented 1 year ago

I'd like to build a little feed-forward, fully-connected network with just one hidden layer. I looked at the examples, but the most relevant one, train.ts, doesn't seem to work anymore: things like sm.module.sequential and sm.optim.Adam no longer seem to exist.

It would be great to get that example fixed.

In general, it would also be great to have a simpler and more exhaustive "getting started" example: a tiny model that learns XOR, showcasing how to build the network (with one hidden layer, for the sake of demonstrating how that's done), how to feed it data for training, and how to validate it on more data afterwards.

At the moment I'm a bit stuck: I have the dataset, and I had the network sort of working on top of Brain.js (too slow), but I don't know what Shumai code I should write to recreate the same network and the same training/testing pipeline.
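For reference, the Brain.js version I had is roughly the following (paraphrased from memory, so treat the exact options as approximate):

```ts
import brain from 'brain.js'

// Tiny XOR network: 2 inputs -> 1 hidden layer -> 1 output.
const net = new brain.NeuralNetwork({ hiddenLayers: [3] })

// Training: Brain.js takes the whole XOR truth table directly.
net.train([
  { input: [0, 0], output: [0] },
  { input: [0, 1], output: [1] },
  { input: [1, 0], output: [1] },
  { input: [1, 1], output: [0] },
])

// "Inference" is just calling run() on new data.
console.log(net.run([1, 0])) // something close to [1]
```

It's this whole build/train/run flow that I'd like to see spelled out for Shumai.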

bwasti commented 1 year ago

whoops, we didn't publish the latest code with sequential and Adam included. Version 0.0.11 should be good now!

Adam is unfortunately a bit buggy at the moment, so I changed the example to use SGD.
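The training step in the updated example has roughly this shape (a minimal sketch, assuming an sm.loss.mse helper and that sm.optim.sgd consumes the gradients returned by backward(); check train.ts in the repo for the exact calls):

```ts
import * as sm from '@shumai/shumai'

// 2 -> 4 -> 1 network; sequential composes callables, including plain functions.
const model = sm.module.sequential(
  sm.module.linear(2, 4),
  (t) => t.sigmoid(), // activation choice here is illustrative
  sm.module.linear(4, 1),
  (t) => t.sigmoid()
)

// The full XOR truth table as a 4x2 input and a 4x1 target.
const x = sm.tensor(new Float32Array([0, 0, 0, 1, 1, 0, 1, 1])).reshape([4, 2])
const y = sm.tensor(new Float32Array([0, 1, 1, 0])).reshape([4, 1])

for (let i = 0; i < 1000; i++) {
  const loss = sm.loss.mse(y, model(x)) // assumed loss helper
  sm.optim.sgd(loss.backward(), 1e-2)   // SGD step with lr = 0.01
}
```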

fabiospampinato commented 1 year ago

Awesome, it seems to do something now, thank you.

Some more questions if you don't mind:

  1. Would it be possible to fix the types too? Ideally these examples should type-check cleanly, but currently there seem to be some type errors at the library level: TS doesn't seem to know about sm.module.sequential, and sm.module.linear, for example, says it accepts inp_dim: any, out_dim: any rather than numbers, which doesn't help when trying to figure things out.
  2. Is inference performed basically just by calling the model with the input, like m(x) in the example? (See the sketch after this list for what I mean in 2-4.)
  3. Is there a way to access the weights of all the layers? I'd like to compile the trained model down to a standalone JS function so that I can use it anywhere easily.
  4. Is the decreasing fractional number I see in the console basically the model's current error? If so, is there an example of how to get this number and log things manually? I'm not sure what that viter function is doing.
  5. It would be very useful to have an example XOR model that uses as few internal utility functions as possible (using sm.util.range in for (const i of sm.util.range(3)) seems unnecessary, for instance) and that has some, even minimal, comments about what these functions do. Maybe most of them are "obvious", but no other ML framework seems to have a function named viter, and I can't find anything about "ema loss" either. Could something like this be added?
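To make 2-4 concrete, this is the kind of thing I'm hoping is possible (a sketch reusing model, x and y from the training snippet above; model.layers and layer.weight are guesses at an API, not things I found in the docs):

```ts
// 2. Inference: just call the trained model on a tensor?
const out = model(sm.tensor(new Float32Array([1, 0])).reshape([1, 2]))
console.log(out.toFloat32Array())

// 3. Weights: some way to walk the layers and dump each weight tensor,
// so the trained model can be compiled down to a standalone JS function.
for (const layer of model.layers) {          // hypothetical property
  console.log(layer.weight.toFloat32Array()) // hypothetical property
}

// 4. Manual logging: compute and print the loss myself each step,
// instead of relying on whatever viter prints.
const loss = sm.loss.mse(y, model(x))
console.log(loss.toFloat32Array()[0])
```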

Thanks.