Closed dc-dc-dc closed 3 months ago
Very nice! I've been wanting to do something like this on the C++ side to simplify writing ops which take streams.
Yeah, it's pretty handy when you only have a device-specific implementation. In C++ you could probably achieve something similar with a lambda + macro.
Love it. We have been putting off doing that for a while now.
If you plan on doing the same in C++ then I'd suggest RAII instead of a lambda. Basically a simple struct that saves the old device and sets the desired device on construction, and resets to the old device on destruction. Managing device scopes is literally done with C++ scopes after that. Also, this should be able to take a stream instead of simply a device so we can use it in the ops or gradients to simplify things if we want to.
💯 I was about to make the exact same comment re stream and RAII on the C++ side.
I don't want to overcomplicate things, but I also wonder if we should do it on the C++ side and bind it in core instead of putting it in mlx.utils.
🤔 . Seems sufficiently useful.
Something like this? (kind of annoying that the signatures don't match, as default_device returns a const ref)
struct StreamContextManager {
 public:
  StreamContextManager(StreamOrDevice s)
      : _stream(default_stream(default_device())) {
    set_default_stream(to_stream(s));
  }
  ~StreamContextManager() {
    set_default_stream(_stream);
  }

 private:
  Stream _stream;
};
Had to move StreamOrDevice into utils for it to compile, which makes sense as it is a utility that could be used elsewhere outside of ops.
For the Python binding, I'd still want to use contextlib to take advantage of function decorators and the with keyword. So I guess on enter create the struct and on exit just delete it?
Pushed some progress, though after calling set_default_device in C++ it doesn't seem to have an effect from Python. I can confirm that the stream is cpu in the context manager, but the subsequent op calls still use gpu. Will mess with this more, but just pushing in case this is something you guys ran into before.
nvm, fixing it now
Updated and added docs, works nicely with the with keyword.
Couldn't get the decorator part working as I'd like, but it's not essential so will just leave it as is for now. If anyone is interested in adding decorator support, essentially you just need to add a __call__ def to the PyStreamContextManager class that takes a py::function argument and returns a wrapper function that calls enter, func(*args, **kwargs), and exit, in that order.
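For anyone picking that up, the shape of the decorator plumbing in pure Python would be something like the sketch below. This is a hypothetical stand-in (it records enter/exit calls rather than swapping mlx streams), not the actual binding code:

```python
import functools

class PyStreamContextManagerSketch:
    """Toy stand-in: records enter/exit instead of swapping mlx streams."""
    log = []

    def __init__(self, stream):
        self.stream = stream

    def __enter__(self):
        PyStreamContextManagerSketch.log.append("enter")
        return self

    def __exit__(self, exc_type, exc, tb):
        PyStreamContextManagerSketch.log.append("exit")
        return False

    def __call__(self, func):
        # Decorator support: run the wrapped function inside this context.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return wrapper

@PyStreamContextManagerSketch("cpu")
def compute():
    PyStreamContextManagerSketch.log.append("call")
    return "done"
```

Calling `compute()` appends "enter", "call", "exit" to the log, i.e. exactly the enter / func / exit ordering described above.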
This is really nice.
My main comment is I find the name to be quite verbose: StreamContextManager.
Do you think we could/should put the enter and exit functions in the Stream class itself? Would be pretty nice to do e.g.:
with mx.Stream(mx.cpu):
...
Saves introducing a new name + it's a lot shorter.
I believe the Stream constructor takes an index alongside the device; we could just add a constructor that only takes a device and auto-creates the index if not specified. We would also need some way for Stream to hold the state of the context manager, and I'm not sure if that might conflict with any existing code.
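To make the idea concrete, folding enter/exit into Stream itself might look like the toy sketch below. All names here are hypothetical stand-ins (the real Stream lives in C++ and also carries a device index assigned by the backend):

```python
class Stream:
    """Toy model of a stream that doubles as its own context manager."""
    _default = None  # stand-in for the global default stream

    def __init__(self, device, index=0):
        self.device = device
        self.index = index

    def __enter__(self):
        # Save the current default and install this stream.
        self._previous = Stream._default
        Stream._default = self
        return self

    def __exit__(self, exc_type, exc, tb):
        # Restore whatever was the default before entering.
        Stream._default = self._previous
        return False

with Stream("cpu") as s:
    assert Stream._default is s
assert Stream._default is None
```

The "_previous" slot stashed on the instance is the context-manager state the comment above is worried about.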
Yea right it's not so elegant on the implementation side. I don't love that, though I think it could certainly be done and without breaking any downstream code.
But let's at least consider alternative names. For example, another option is that lowercase stream is the context manager:
with mx.stream(...):
....
Personally I like that a lot more than mx.StreamContextManager, even though it introduces a bit of ambiguity between Stream and stream.
Any other suggestions?
I think dropping the manager part and leaving it as StreamContext fits its purpose.
with mx.StreamContext(mx.cpu):
...
> I think dropping the manager part and leaving it as StreamContext fits its purpose.

I still like simple mx.stream more but that is certainly an improvement.
Just for some points of comparison:
- torch.cuda.device(..) as the context manager
- jax.default_device(..) as the context manager

For streams torch uses a class called StreamContext, and they have a wrapper called stream that returns an instance of it.
> For streams torch uses a class called StreamContext, and they have a wrapper called stream that returns an instance of it.
Lol, so they do both of our suggestions. What's the point of having both? Wouldn't one just always use torch.stream?
Haha yeah, I guess it's mainly to follow the CapWords convention for class names, but they included a shorter function name that wraps around it.
Added the stream wrapper function so we support both options, similar to torch.
Wonderful, thanks!! That's really nice!
Two more comments, then I think we are good:
- Should we name the C++ side StreamContext as well? Seems better to be consistent there.
- Let's update the docs to use the stream wrapper. I think it would be good to have that as the preferred option.

I pushed a few changes to the docs to fix some issues there.
Also I noticed that the following runs:
with mx.stream(None):
pass
I think that should probably not run. Do you mind fixing that?
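One way to reject that, sketched against a toy stand-in for the wrapper (hypothetical names; the real fix would live in the binding's argument handling):

```python
class StreamContext:
    """Toy stand-in for the PR's context manager."""
    def __init__(self, s):
        if s is None:
            # Fail fast instead of silently entering a no-op context.
            raise ValueError("expected a Stream or Device, got None")
        self.s = s

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False

def stream(s):
    """Lower-case convenience wrapper around the context-manager class."""
    return StreamContext(s)
```

With that check in place, `stream(None)` raises at construction time rather than running as an empty context.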
Proposed changes
Utilize Python's contextlib to easily swap out which device is being used without having to keep track of it manually; updated the fft tests to showcase its effectiveness.
Checklist

Put an x in the boxes that apply.

- [ ] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes