ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Wrong output was obtained when x and y both have inf in mlx.core.where #598

Closed Redmept1on closed 6 months ago

Redmept1on commented 6 months ago

Describe the bug: mlx.core.where returns the wrong output when x and y both contain inf. Whether the selected element comes from x or y, the result should be inf, not nan.

To Reproduce

import mlx.core as mx
import numpy as np
condition_mx = mx.array([True, False, True, False], dtype=mx.bool_)
x_mx = mx.array([float('inf'), 2, 3, 4], dtype=mx.float32)
y_mx = mx.array([float('inf'), 20, 30, 40], dtype=mx.float32)

condition_np = np.array([True, False, True, False])
x_np = np.array([float('inf'), 2, 3, 4], dtype=np.float32)
y_np = np.array([float('inf'), 20, 30, 40], dtype=np.float32)

result_mlx = mx.where(condition_mx, x_mx, y_mx)
print("Result on Mlx:", result_mlx)

result_np = np.where(condition_np, x_np, y_np)
print("Result using NumPy:", result_np)


Expected behavior: array([inf, 20, 3, 40], dtype=float32)


Rifur13 commented 6 months ago

You can assign this to me, I have a good implementation in mind.



mx.where is implemented using arithmetic on other primitives, but arithmetic on inf values is not always well defined (for example, inf * 0 is NaN), so you get NaN values. Handling inf values correctly here also fixes #576, because mx.logsumexp is implemented using mx.where.
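
To illustrate the failure mode, here is a minimal NumPy sketch of a select built from arithmetic. The function name `where_arithmetic` and the exact formula (`c * x + (1 - c) * y`) are assumptions for illustration, not MLX's actual implementation; the point is only that any such formula multiplies the unselected inf by zero.

```python
import numpy as np

def where_arithmetic(cond, x, y):
    # Hypothetical arithmetic-based select: c * x + (1 - c) * y.
    # This is an illustrative stand-in, not the MLX source.
    c = cond.astype(x.dtype)
    with np.errstate(invalid="ignore"):  # silence the inf * 0 warning
        return c * x + (1 - c) * y

cond = np.array([True, False, True, False])
x = np.array([np.inf, 2, 3, 4], dtype=np.float32)
y = np.array([np.inf, 20, 30, 40], dtype=np.float32)

arith = where_arithmetic(cond, x, y)  # first element: 1 * inf + 0 * inf
selected = np.where(cond, x, y)       # true selection, no arithmetic

print("arithmetic select:", arith)
print("np.where:         ", selected)
```

In the first position the term `0 * inf` evaluates to NaN under IEEE 754, and NaN propagates through the sum, whereas a true selection never touches the unselected operand.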

awni commented 6 months ago

What do you have in mind @Rifur13 ?

Rifur13 commented 6 months ago

A new Where primitive for conditional selection. I don’t see a way of fixing this with existing primitives. I’ll benchmark it on cpu/gpu - it should also be more performant.

awni commented 6 months ago

Right exactly.. I was wondering if you had some idea other than a new primitive, but that is the only way I've thought of so far 😄

If you are up for implementing it, that would be great. You can see how our binary op primitives work as a good starting point. It might 🤔 be worth doing something similar but ternary instead.
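
As a rough picture of what "binary but ternary" could mean, the sketch below applies a three-argument scalar function elementwise over broadcast inputs. It is written in NumPy, and the names `ternary_op` and `select` are hypothetical; it does not reflect how MLX's binary op primitives or the eventual Where primitive are actually written.

```python
import numpy as np

def ternary_op(a, b, c, op):
    """Apply a three-argument scalar function elementwise with broadcasting.

    Hypothetical helper for illustration only.
    """
    a, b, c = np.broadcast_arrays(np.asarray(a), np.asarray(b), np.asarray(c))
    out = np.empty(a.shape, dtype=np.result_type(b, c))
    for idx in np.ndindex(a.shape):
        out[idx] = op(a[idx], b[idx], c[idx])
    return out

def select(cond, x, y):
    # Pure selection: the unselected value is never used in arithmetic.
    return x if cond else y

cond = np.array([True, False, True, False])
x = np.array([np.inf, 2, 3, 4], dtype=np.float32)
y = np.float32(-np.inf)  # a scalar broadcasts against cond and x

result = ternary_op(cond, x, y, select)
print(result)
```

Because `select` only picks one operand per element, inf values pass through unchanged regardless of which input they come from.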

ozankabak commented 6 months ago

@Rifur13, I am also interested in this. Happy to help if you need it.

Rifur13 commented 6 months ago

Thanks! I'm almost done but we can iterate on a design in my PR if needed

awni commented 6 months ago

@Rifur13 are you still working on this one?

Rifur13 commented 6 months ago

Yep, let me clean it up a bit and I'll send out the PR.