Closed Redmept1on closed 6 months ago
You can assign this to me, I have a good implementation in mind.
mx.where is implemented using math on other primitives, but inf values don’t support arithmetic so you get NaN values.
Correctly handling inf values here also fixes #576, because mx.logsumexp is implemented using mx.where.
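To illustrate why arithmetic on inf produces NaN here, below is a minimal pure-Python sketch. The formula cond * x + (1 - cond) * y is an assumption about what "math on other primitives" could look like, not necessarily the exact expression MLX used, and the input values are made up for illustration rather than taken from the report.

```python
import math

def where_arithmetic(cond, x, y):
    # Assumed arithmetic formulation (hypothetical, for illustration only):
    # cond * x + (1 - cond) * y, with cond given as 0.0 or 1.0.
    return [c * a + (1.0 - c) * b for c, a, b in zip(cond, x, y)]

# Illustrative inputs: each inf sits in the branch that is NOT selected.
cond = [0.0, 1.0]
x = [math.inf, 2.0]
y = [1.0, math.inf]

result = where_arithmetic(cond, x, y)
print(result)  # [nan, nan], because 0.0 * inf is nan under IEEE 754
```

The unselected branch is multiplied by zero rather than skipped, so an inf there still poisons the sum.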
What do you have in mind @Rifur13 ?
A new Where primitive for conditional selection. I don’t see a way of fixing this with existing primitives. I’ll benchmark it on cpu/gpu - it should also be more performant.
Right exactly.. I was wondering if you had some idea other than a new primitive, but that is the only way I've thought of so far 😄
If you are up for implementing it that would be great. You can see how our binary op primitives work as a good starting point. It might 🤔 be worth doing something similar but ternary instead.
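As a rough sketch of the "binary but ternary" idea, the pure-Python snippet below puts a two-input elementwise helper next to a three-input one and uses the latter for selection. The names binary_op, ternary_op, and select are hypothetical and do not correspond to MLX's actual C++ primitives; broadcasting, dtypes, and strides are ignored.

```python
import math

def binary_op(a, b, op):
    # Elementwise application of a two-argument function.
    return [op(x, y) for x, y in zip(a, b)]

def ternary_op(a, b, c, op):
    # Same pattern, extended to a three-argument function.
    return [op(x, y, z) for x, y, z in zip(a, b, c)]

def select(cond, x, y):
    # Pure selection: the unselected value is never used in arithmetic.
    return x if cond else y

# Illustrative inputs with inf in the unselected positions.
out = ternary_op([False, True], [math.inf, 2.0], [1.0, math.inf], select)
print(out)  # [1.0, 2.0]
```

Because the selection picks a value instead of combining both, an inf in the unselected operand cannot turn the result into NaN.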
@Rifur13, I am also interested in this. Happy to help if you need
Thanks! I'm almost done but we can iterate on a design in my PR if needed
@Rifur13 are you still working on this one?
Yep, let me clean it up a bit and I'll send out the PR.
Describe the bug mlx.core.where returns the wrong output when x and y both contain inf. Whether the selected value comes from x or y, the result should be inf instead of nan.
To Reproduce
Expected behavior array([inf, 20, 3, 40], dtype=float32)
Desktop (please complete the following information):