oxideai / mlx-rs

Unofficial Rust bindings to Apple's mlx framework
https://oxideai.github.io/mlx-rs/
Apache License 2.0

Initial attempt at implementing `mlx-nn` NN modules #100

Open minghuaw opened 1 month ago

minghuaw commented 1 month ago

This is my initial attempt at implementing the neural net modules and optimizers. The following are included in this PR

minghuaw commented 1 month ago

This is just an initial attempt at implementing some of the neural net components. Some feedback on the overall API design and ergonomics would be nice, @dcvz.

minghuaw commented 1 month ago

Ended up changing the `Module` trait to take `&Array`. The only cost (so far) is that an empty `Sequential` ends up calling `deep_clone` on the input array, but this should make overall usage more flexible, since most (if not all) ops only require a reference to their input.
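A minimal sketch of that trade-off, assuming a hypothetical stand-in `Array` (a `Vec<f32>` wrapper, not the real mlx-rs type) and hypothetical `Relu` and `Sequential` types. Because `forward` only borrows its input but must return an owned `Array`, a `Sequential` with no layers has nothing to return except a copy of the input:

```rust
/// Hypothetical stand-in for an mlx-rs array; NOT the real type.
#[derive(Debug, Clone, PartialEq)]
pub struct Array(pub Vec<f32>);

impl Array {
    /// Stand-in for a deep copy of the underlying data.
    pub fn deep_clone(&self) -> Array {
        Array(self.0.clone())
    }
}

/// Sketch of a `Module` trait whose forward pass borrows its input.
pub trait Module {
    fn forward(&self, x: &Array) -> Array;
}

/// Hypothetical element-wise ReLU layer: a reference to the input is enough.
pub struct Relu;

impl Module for Relu {
    fn forward(&self, x: &Array) -> Array {
        Array(x.0.iter().map(|v| v.max(0.0)).collect())
    }
}

/// Hypothetical sequential container that chains layers in order.
pub struct Sequential {
    pub layers: Vec<Box<dyn Module>>,
}

impl Module for Sequential {
    fn forward(&self, x: &Array) -> Array {
        let mut iter = self.layers.iter();
        // With no layers we still owe the caller an owned Array,
        // so the only option is to copy the borrowed input.
        let Some(first) = iter.next() else {
            return x.deep_clone();
        };
        let mut out = first.forward(x);
        for layer in iter {
            out = layer.forward(&out);
        }
        out
    }
}
```

Only the empty case pays for a copy; every non-empty chain passes references from one layer to the next.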

minghuaw commented 1 week ago

@dcvz This is still missing a bunch of docs and tests, but I think we could start reviewing some key impls. I also think this is probably a good point to stop; otherwise it might be too much for a single PR. The remaining NN layers/modules should be added in separate PRs.

  1. The `mlx_nn::value_and_grad::value_and_grad()` function. This is where I had the discussion about whether `model.update(...)` in the Swift binding is necessary, and it turns out that we need it too. Essentially, if we don't update the model's parameters to the ones passed in, the gradients will be all zeros. This, however, means the closure returned by `value_and_grad` requires a mutable reference to the model, which seems a bit counterintuitive (why would computing the value and gradients require a mutable borrow?). So the question is whether there's a better way to handle this, or whether we should just clarify it in the docs.
  2. Traits and macros related to `Module`. The `Module` trait is split out into a new crate, `mlx-nn-module`, to make it easier to handle the dependency problem in the derive macro impl. We need to use this macro in `mlx-nn`, but users should be able to use it in their own crates as well.
  3. There are a couple of changes to the base `mlx-rs` crate. These should all be what's necessary to make `mlx-nn` work.

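
The zero-gradient issue in item 1 can be reproduced with a toy model. This sketch is not the mlx-rs API: `Linear1d`, `update`, and this `value_and_grad` are hypothetical, and central finite differences stand in for MLX's autodiff. Because the loss reads parameters out of the model, the candidate parameters have to be written back into it (which needs `&mut`); with `write_back` off, the perturbed parameters never reach the loss and every gradient comes out as zero:

```rust
/// Hypothetical toy model y = w * x + b; parameters live inside the model.
#[derive(Debug)]
pub struct Linear1d {
    pub w: f64,
    pub b: f64,
}

impl Linear1d {
    pub fn parameters(&self) -> Vec<f64> {
        vec![self.w, self.b]
    }

    /// Analogue of `model.update(...)`: overwrite the stored parameters.
    pub fn update(&mut self, params: &[f64]) {
        self.w = params[0];
        self.b = params[1];
    }

    pub fn forward(&self, x: f64) -> f64 {
        self.w * x + self.b
    }
}

/// Squared-error loss; note it reads parameters from the model, not from `params`.
fn loss(model: &Linear1d, x: f64, y: f64) -> f64 {
    let d = model.forward(x) - y;
    d * d
}

/// Evaluate the loss at `params`, optionally writing them into the model first.
fn eval(model: &mut Linear1d, params: &[f64], x: f64, y: f64, write_back: bool) -> f64 {
    if write_back {
        model.update(params);
    }
    loss(model, x, y)
}

/// Value and gradient w.r.t. `params`, via central finite differences.
pub fn value_and_grad(
    model: &mut Linear1d,
    params: &[f64],
    x: f64,
    y: f64,
    write_back: bool,
) -> (f64, Vec<f64>) {
    let eps = 1e-6;
    let mut grads = Vec::with_capacity(params.len());
    for i in 0..params.len() {
        let mut plus = params.to_vec();
        let mut minus = params.to_vec();
        plus[i] += eps;
        minus[i] -= eps;
        let lp = eval(model, &plus, x, y, write_back);
        let lm = eval(model, &minus, x, y, write_back);
        grads.push((lp - lm) / (2.0 * eps));
    }
    // Final evaluation at `params` also leaves the model holding them.
    let value = eval(model, params, x, y, write_back);
    (value, grads)
}
```

With `w = 1`, `b = 0`, `x = 2`, `y = 0`, the loss is `(2w + b)^2 = 4`. Writing the parameters back gives gradients of about `[8, 4]`; skipping it gives `[0, 0]`, which matches the symptom described above.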
minghuaw commented 6 days ago

Added the remaining docs. I'm wondering whether we should add equations to the docs. If so, how? LaTeX equations don't seem to be supported. Or should we just link to the Python or Swift binding docs?
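One low-tech option, sketched under assumptions: keep the formula as plain text in a `text` block inside the doc comment and point readers to the Python docs for the rendered version. (Some crates instead load KaTeX through rustdoc's `--html-in-header` flag, which needs extra configuration.) The `linear` function below is a hypothetical plain-slice example, not part of mlx-nn; the formula `y = x W^T + b` follows the convention of MLX's `Linear` layer, with one row of `W` per output.

```rust
/// Applies an affine transformation to the input (hypothetical plain-slice version).
///
/// ```text
/// y = x W^T + b
/// ```
///
/// `w` has one row per output (shape `[output_dims][input_dims]`) and `b` has
/// `output_dims` entries. See the corresponding Python `mlx.nn.Linear` docs for
/// the rendered equation.
pub fn linear(x: &[f32], w: &[Vec<f32>], b: &[f32]) -> Vec<f32> {
    w.iter()
        .zip(b)
        // Each output is the dot product of one row of W with x, plus its bias.
        .map(|(row, bias)| row.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f32>() + bias)
        .collect()
}
```

The `text` fence keeps rustdoc from trying to compile the formula as a doctest while still rendering it verbatim.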