ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[BUG] Unstable results in sin/arcsin/arccos calls #450

Open Redmept1on opened 10 months ago

Redmept1on commented 10 months ago

Describe the bug

I previously reproduced a bug in PyTorch, and when I converted the PyTorch code to MLX, I found that the problem still exists. The PyTorch problem was confirmed as a bug. Like the issue I submitted earlier (https://github.com/ml-explore/mlx/issues/439), it may be caused by the CPU and GPU using different math libraries. I would like to know whether this is also considered a bug in MLX.

I get unstable results when sin, arcsin, and arccos are called in sequence. Applying sin followed by arcsin should give back the original value, but the round trip is not exact, and the subsequent arccos produces inconsistent results: 0 on the CPU and nan on the GPU.

To Reproduce


import mlx.core as mx

x = mx.ones(2)

# Default device (the GPU on Apple silicon)
y = mx.sin(x)
y = mx.arcsin(y)  # inverse of sin, so this should be [1, 1]
y = mx.arccos(y)
print(mx.default_device(), y)  # gives [nan, nan]

# Same computation on the CPU
mx.set_default_device(mx.cpu)
y1 = mx.sin(x)
y1 = mx.arcsin(y1)  # should also be [1, 1]
y1 = mx.arccos(y1)
print(mx.default_device(), y1)  # gives [0, 0]


awni commented 10 months ago

This is interesting:

x = mx.array(1.0)
y = mx.arcsin(mx.sin(x))
print(y > 1.0) # Evaluates to True

In infinite precision the result would be exactly 1.0, but here it is not. Once the value is greater than 1.0, calling arccos is invalid, since arccos is only defined on [-1.0, 1.0].
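As a rough illustration of how small the excess can be, the sketch below uses plain Python rather than MLX, assuming the GPU computation is in float32 (the default dtype of `mx.ones`). It constructs the smallest float32 strictly greater than 1.0, which is only one rounding step (1 + 2**-23) away, and shows that it already lies outside the domain of arccos. The `to_float32` helper is hypothetical and written only for this example. Python's `math.acos` raises `ValueError` for out-of-domain input, whereas MLX on the GPU returned nan in the report above.

```python
import math
import struct


def to_float32(v: float) -> float:
    """Round a Python float (a double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", v))[0]


# The smallest float32 strictly greater than 1.0 is 1 + 2**-23.
one_ulp_above = to_float32(1.0 + 2.0**-23)
print(one_ulp_above > 1.0)  # True

# arccos is only defined on [-1, 1], so even this tiny excess is invalid.
try:
    math.acos(one_ulp_above)
except ValueError:
    print("acos raised ValueError")
```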

This seems to be a problem with metal::precise::asin(x), which is a Metal function and is likely why you see the same issue in PyTorch.

I would call this a bug, but I do not know whether we can fix it in our code: we can't clamp the inputs to asin, because in the general case it is impossible to know whether they should be valid. The alternative would be a custom implementation of asin, which is not inconceivable but is more involved and may carry a performance cost.
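Although the library cannot know whether an input should be valid, a caller who does know (as in this sin/arcsin/arccos chain) can clamp on their side. The sketch below shows the idea in plain Python; `clamped_acos` is a hypothetical helper, not part of MLX or the standard library. In MLX the same idea would presumably be written as `mx.arccos(mx.clip(y, -1.0, 1.0))`. Note the trade-off raised above: clamping silently hides genuinely invalid inputs, so it is only appropriate when any excess is known to be rounding error.

```python
import math


def clamped_acos(v: float) -> float:
    """arccos that first clamps v into [-1, 1].

    Hypothetical helper: only use it when v is mathematically known to be
    in range and any excess is rounding error, because it also hides
    genuinely invalid inputs.
    """
    return math.acos(min(1.0, max(-1.0, v)))


# A value one float32 step above 1.0 now maps to acos(1.0) == 0.0,
# matching the CPU result in the report instead of failing.
print(clamped_acos(1.0 + 2.0**-23))  # 0.0
print(clamped_acos(0.5))             # approximately pi / 3
```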

I will label this as Bug for now until we figure out how to deal with it.