GalacticDynamics / coordinax

Coordinates in JAX

Vector times array #173

Open adrn opened 2 months ago

adrn commented 2 months ago

Currently, this does not work:

>>> import jax.numpy as jnp
>>> import coordinax as cx
>>> from unxt import Quantity

>>> q = Quantity([1, 2, 3], "kpc")
>>> vec = cx.CartesianPosition3D.constructor(q)
>>> scale = jnp.array([1, 0.9, 0.85])
>>> vec * scale
EquinoxTracetimeError: ...

I think this should be supported when the shape of the multiplicative array is broadcastable with the vector's shape, and it should behave the same as q * scale.
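In the meantime, the equivalent result can be obtained by applying the scale at the Quantity level before constructing the vector (a minimal sketch, assuming Quantity * array broadcasts elementwise as in the q * scale reference above):

# Workaround sketch: scale the Quantity first, then build the vector from it.
scaled_vec = cx.CartesianPosition3D.constructor(q * scale)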

(thinking about this as I work on a PR for a new AnisotropicScalingOperator)

nstarman commented 2 months ago

This is meant to be a shorthand for $\mathbf{v}^{\prime} = S \mathbf{v}$ where

$$ S = \begin{bmatrix} s_1 & 0 & \cdots & 0 \\ 0 & s_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & s_n \end{bmatrix} $$

I'm fully in support of explicit matrix-vector multiplication.

import jax.numpy as jnp

scale = jnp.array([1.0, 0.9, 0.85])
jnp.diag(scale) @ vec  # explicit matrix-vector multiplication

IMO less so for implicit matrix-vector multiplication when it looks like a scalar-vector multiplication. For that, we should defer to existing NumPy semantics, e.g. numpy.matmul().

adrn commented 2 months ago

I hear you, but it seems like a waste to fill out a 2D matrix with zeros for a simple diagonal scaling operation! Vectors have a .norm() -- what about a .scale() method, e.g. def scale(self, scale_factors: Array) -> PosT?
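Something along these lines, written here as a standalone helper rather than a method (an untested sketch; the component names x, y, z and keyword construction are assumed from CartesianPosition3D, and scale_vector is a hypothetical name):

# Sketch of what a .scale() could do, as a free function.
import jax.numpy as jnp
import coordinax as cx

def scale_vector(vec: cx.CartesianPosition3D, scale_factors) -> cx.CartesianPosition3D:
    """Scale each Cartesian component by the corresponding factor."""
    f = jnp.broadcast_to(jnp.asarray(scale_factors), (3,))
    return cx.CartesianPosition3D(x=vec.x * f[0], y=vec.y * f[1], z=vec.z * f[2])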

nstarman commented 2 months ago

That seems like a good workaround! We get to keep the mathematical rigour of multiplication, but offer the same convenience through a dedicated method.

nstarman commented 1 month ago

On a related note, sparse array support would be useful here! https://jax.readthedocs.io/en/latest/jax.experimental.sparse.html
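For instance, the diagonal scaling matrix could be stored sparsely so no dense block of zeros is materialized (a minimal sketch using jax.experimental.sparse BCOO; untested here):

from jax.experimental import sparse
import jax.numpy as jnp

scale = jnp.array([1.0, 0.9, 0.85])
S = sparse.BCOO.fromdense(jnp.diag(scale))  # only the 3 nonzero entries are stored
S @ jnp.array([1.0, 2.0, 3.0])              # sparse matrix-vector product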