keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Feature request: keras.ops.linalg.lstsq #19678

Closed by jonbarron 1 week ago

jonbarron commented 1 week ago

Can this be added? It seems like it should be straightforward to implement, as it's just a thin wrapper around an SVD (in JAX at least: https://github.com/google/jax/blob/main/jax/_src/numpy/linalg.py#L1368-L1406), and y'all already have SVD implemented.
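To make the "thin wrapper around an SVD" idea concrete, here is a minimal NumPy sketch. It is not the JAX or Keras code: the name `lstsq_via_svd` is hypothetical, and the default `rcond` cutoff (machine epsilon times the larger matrix dimension) is an assumption chosen to mirror NumPy's `rcond=None` behaviour.

```python
import numpy as np


def lstsq_via_svd(a, b, rcond=None):
    """Minimum-norm least-squares solution of a @ x = b (illustrative sketch)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # Thin SVD: a = u @ diag(s) @ vt
    u, s, vt = np.linalg.svd(a, full_matrices=False)
    if rcond is None:
        # Assumed default, mirroring NumPy's rcond=None convention.
        rcond = np.finfo(a.dtype).eps * max(a.shape)
    # Invert only singular values above the cutoff; treat the rest as zero.
    keep = s > rcond * s.max()
    s_inv = np.where(keep, 1.0 / np.where(keep, s, 1.0), 0.0)
    # x = V @ diag(s_inv) @ U^T @ b, written to handle both 1-D and 2-D b.
    utb = u.T @ b
    return vt.T @ (s_inv * utb.T).T


# Example: fit a line y = c0 + c1 * t through three points.
a = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])
b = np.array([1.0, 2.0, 2.0])
x = lstsq_via_svd(a, b)
```

Only `x` is computed here; the residuals, rank, and singular values that `numpy.linalg.lstsq` also returns are left out.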

fchollet commented 1 week ago

Yes, that's in scope and it seems straightforward.

I'm curious, what are you building that requires all these niche linalg ops?

fchollet commented 1 week ago

I've investigated this, and it's actually impossible to achieve consistency across backends for the resid and s return values (the solution x, the first return value, is fine). Even jax.numpy isn't consistent with numpy. Torch does it differently from both as well (despite having the same API).

We could make the function only return x. Are the other values ever useful?

Looking at code in the wild I could only find samples that used x. The other values are apparently always discarded. It's a pretty weird API tbh.
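For context, this is what the four return values look like in NumPy itself. The sketch only illustrates one documented NumPy edge case (the residuals array is empty when the matrix is rank-deficient); it does not show which specific backend differences were found, since the thread does not say.

```python
import numpy as np

b = np.array([1.0, 2.0, 2.0])

# Full-rank, overdetermined system: residuals has shape (1,) for 1-D b.
a_full = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])
x, resid, rank, s = np.linalg.lstsq(a_full, b, rcond=None)
print(resid.shape, rank)

# Rank-deficient system: NumPy documents that residuals is an empty array.
a_def = np.ones((3, 2))
x_def, resid_def, rank_def, s_def = np.linalg.lstsq(a_def, b, rcond=None)
print(resid_def.shape, rank_def)
```

The shape of `resid` depends on the rank of the input, so it is data-dependent, while `x` has the same shape in both cases.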

jonbarron commented 1 week ago

Yeah, I think it makes sense to only return x, and I can imagine very few use cases where the caller would really care about the other outputs. It might be helpful to return None for the other 3 outputs of the numpy interface just so it's a drop-in replacement for JAX. The other upside to having placeholder None outputs is that it avoids a potential footgun in case someone happens to be solving for an x with x.shape[0] == 4: if they call lstsq with 4 output slots, it will silently unstack the tensor along the first dimension, which is definitely not what the caller would want.
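A small NumPy sketch of that footgun, assuming a hypothetical `lstsq_x_only` helper that stands in for an x-only API (it is not the Keras function):

```python
import numpy as np


def lstsq_x_only(a, b):
    # Hypothetical stand-in for an API that returns only the solution x.
    return np.linalg.lstsq(a, b, rcond=None)[0]


a = np.eye(4)                      # 4 unknowns, so x.shape[0] == 4
b = np.arange(8.0).reshape(4, 2)   # two right-hand sides

# NumPy-style 4-way unpacking raises no error here: each name simply
# receives one row of the (4, 2) solution instead of a separate output.
x, resid, rank, s = lstsq_x_only(a, b)
print(x.shape)  # (2,), not (4, 2)

# With any other number of unknowns the same unpacking fails loudly.
try:
    x3, resid3, rank3, s3 = lstsq_x_only(np.eye(3), np.ones(3))
except ValueError as err:
    print("unpacking failed:", err)
```

The silent case only arises when the leading dimension of `x` matches the number of unpacking targets, which is why it is easy to miss in testing.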

fchollet commented 1 week ago

For reference, TensorFlow's tf.linalg.lstsq only returns x https://www.tensorflow.org/api_docs/python/tf/linalg/lstsq (though it also has a different signature).

I added the API, only returning x. Returning None entries would be problematic since ops are only ever supposed to return tensors.