ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Adds device context manager #679

Closed dc-dc-dc closed 3 months ago

dc-dc-dc commented 3 months ago

Proposed changes

Use Python's contextlib to swap the active device easily, without having to keep track of the previous one. Updated the fft tests to showcase its effectiveness.
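
As a rough illustration of the idea, here is a minimal sketch of a contextlib-based device scope. It does not use MLX itself: `_state`, `default_device`, `set_default_device`, and `device_scope` are hypothetical stand-ins defined in the snippet, and devices are plain strings. It is meant to show the save/switch/restore pattern only, not the code in this PR.

```python
from contextlib import contextmanager

# Hypothetical stand-in for a framework's global default device.
_state = {"device": "gpu"}


def default_device():
    """Return the current (stand-in) default device."""
    return _state["device"]


def set_default_device(device):
    """Set the (stand-in) default device."""
    _state["device"] = device


@contextmanager
def device_scope(device):
    """Temporarily switch the default device; restore the old one on exit."""
    previous = default_device()
    set_default_device(device)
    try:
        yield device
    finally:
        # Runs even if the body raises, so the old device always comes back.
        set_default_device(previous)


with device_scope("cpu"):
    inside = default_device()
after = default_device()
print(inside, after)  # prints: cpu gpu
```

The caller never has to remember which device was active before the block; the `finally` clause handles the restore.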

awni commented 3 months ago

Very nice! I've been wanting to do something like this on the C++ side to simplify writing ops which take streams.

dc-dc-dc commented 3 months ago

Yeah, it's pretty handy when you only have a device-specific implementation. In C++ you could probably achieve something similar with a lambda + macro.

angeloskath commented 3 months ago

Love it. We have been putting off doing that for a while now.

If you plan on doing the same in C++ then I'd suggest RAII instead of a lambda: basically a simple struct that keeps the old device, sets the desired device on construction, and resets to the old device on destruction. Managing device scopes is literally done with C++ scopes after that. Also, this should be able to take a stream instead of simply a device, so we can use it in the ops or gradients to simplify things if we want to.

awni commented 3 months ago

💯 I was about to make the exact same comment re stream and RAII on the C++ side.

I don't want to overly complexify things, but I also wonder if we should do it on the C++ side and bind it in core instead of putting it in mlx.utils 🤔 . Seems sufficiently useful.

dc-dc-dc commented 3 months ago

Something like this? (Kind of annoying that the signatures don't match, as default_device returns a const ref.)

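struct StreamContextManager {
 public:
  // Remember the current default stream, then switch to the requested one.
  StreamContextManager(StreamOrDevice s)
      : _stream(default_stream(default_device())) {
    set_default_stream(to_stream(s));
  }

  // Restore the stream that was the default when this object was created.
  ~StreamContextManager() {
    set_default_stream(_stream);
  }

 private:
  Stream _stream;
};
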
dc-dc-dc commented 3 months ago

Had to move StreamOrDevice into utils for it to compile, which makes sense as it is a utility type that could be used elsewhere outside of ops.

For the Python binding, I'd still want to use contextlib to take advantage of function decorators and the with keyword. So I guess on enter we create the struct and on exit we just delete it?

dc-dc-dc commented 3 months ago

Pushed some progress, though after calling set_default_device in C++ it doesn't seem to have an effect from Python. I can confirm that the stream is cpu inside the context manager, but subsequent op calls still use gpu. Will mess with this more, but just pushing in case this is something you guys ran into before.

dc-dc-dc commented 3 months ago

nvm, fixing it now

dc-dc-dc commented 3 months ago

Updated and added docs, works nicely with the with keyword.

Couldn't get the decorator part working as I'd like, but it's not essential, so I'll leave it as is for now. If anyone is interested in adding decorator support, you essentially just need to add a __call__ method to the PyStreamContextManager class that takes a py::function argument and returns a wrapper function that calls enter, func(*args, **kwargs), and exit, in that order.

awni commented 3 months ago

This is really nice.

My main comment is I find the name to be quite verbose: StreamContextManager.

Do you think we could/should put the enter and exit functions in the Stream class itself? Would be pretty nice to do e.g.:

with mx.Stream(mx.cpu):
   ...

Saves introducing a new name + it's a lot shorter.

dc-dc-dc commented 3 months ago

I believe the Stream constructor takes an index alongside the device. We could just add a constructor that only takes a device and auto-creates the index if not specified. We would also need a wrapper around Stream to hold the state of the context manager; not sure if this might conflict with any code.

awni commented 3 months ago

Yeah, right, it's not so elegant on the implementation side. I don't love that, though I think it could certainly be done without breaking any downstream code.

But let's at least consider alternative names. For example, another option is to make lower-case stream the context manager:

with mx.stream(...):
   ....

Personally I like that a lot more than mx.StreamContextManager even though it introduces a bit of ambiguity between Stream and stream.

Any other suggestions?

dc-dc-dc commented 3 months ago

I think dropping the manager part and leaving it as StreamContext fits its purpose.

with mx.StreamContext(mx.cpu):
   ...

awni commented 3 months ago

> I think dropping the manager part and leaving it as StreamContext fits its purpose.

I still like simple mx.stream more but that is certainly an improvement.

Just for some points of comparison:

dc-dc-dc commented 3 months ago

For streams, torch uses a class called StreamContext; they have a wrapper called stream that returns an instance of it.

awni commented 3 months ago

> For streams, torch uses a class called StreamContext; they have a wrapper called stream that returns an instance of it.

Lol so they do both of our suggestions. What's the point of having both, wouldn't one just always use torch.stream?

dc-dc-dc commented 3 months ago

Haha yeah, I guess it's mainly to follow the CapWords convention for class names, but they included a shorter-named function that wraps around it.

dc-dc-dc commented 3 months ago

Added the stream wrapper function so we support both options, similar to torch.
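
For readers following along, the "class plus lower-case wrapper" arrangement might look roughly like the sketch below. It is a pure-Python illustration, not the bound MLX code: this `StreamContext` is a hypothetical stand-in that tracks the active target in a class attribute, and targets are plain strings.

```python
class StreamContext:
    """CapWords class that does the actual work (hypothetical stand-in)."""

    _current = "gpu"  # hypothetical record of the active stream/device

    def __init__(self, target):
        self.target = target
        self._saved = None

    def __enter__(self):
        self._saved = StreamContext._current
        StreamContext._current = self.target
        return self

    def __exit__(self, *exc_info):
        StreamContext._current = self._saved
        return False  # propagate any exception from the body


def stream(target):
    """Short lower-case wrapper: builds and returns a StreamContext."""
    return StreamContext(target)


with stream("cpu") as ctx:
    active = StreamContext._current
restored = StreamContext._current
```

Both spellings end up in the same code path, so `with stream("cpu"):` and `with StreamContext("cpu"):` behave identically; the function only saves typing.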

awni commented 3 months ago

Wonderful, thanks!! That's really nice!

Two more comments, then I think we are good:

awni commented 3 months ago

I pushed a few changes to the docs to fix some issues there.

Also I noticed that the following runs:

with mx.stream(None):
  pass

That should probably not run I think. Do you mind fixing that?