GalacticDynamics / coordinax

Coordinates in JAX
MIT License
12 stars 3 forks source link

Testing equality with array values fails #246

Open adrn opened 3 days ago

adrn commented 3 days ago

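Example (comparing a vector whose components have different shapes — here `rho` and `phi` have two elements while `z` is a scalar — raises a `ValueError`):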

q = cx.CylindricalPos(
    rho=Quantity([1.0, 2.0], "kpc"),
    phi=Quantity([0.0, 0.2], "rad"),
    z=Quantity(0.0, "kpc")
)
q == q
...
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/var/folders/67/2zgxpmyd2z183j4k6r33nf740000gr/T/ipykernel_73958/1255407117.py in ?()
      2     rho=Quantity([1.0, 2.0], "kpc"),
      3     phi=Quantity([0.0, 0.2], "rad"),
      4     z=Quantity(0.0, "kpc")
      5 )
----> 6 q == q

    [... skipping hidden 1 frame]

~/projects/coordinax/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py in ?(args, kwargs, bound, memos)
    445                         else:
    446                             raise TypeCheckError(msg) from e
    447 
    448                 # Actually call the function.
--> 449                 out = fn(*args, **kwargs)
    450 
    451                 if full_signature.return_annotation is not inspect.Signature.empty:
    452                     # Now type-check the return value. We need to include the

~/projects/coordinax/src/coordinax/_src/vectors/base/base_pos.py in ?(self, other)
    132         if not isinstance(other, AbstractPos):
    133             return NotImplemented
    134 
    135         rhs = other.represent_as(type(self))
--> 136         return super().__eq__(rhs)

    [... skipping hidden 1 frame]

~/projects/coordinax/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py in ?(args, kwargs, bound, memos)
    445                         else:
    446                             raise TypeCheckError(msg) from e
    447 
    448                 # Actually call the function.
--> 449                 out = fn(*args, **kwargs)
    450 
    451                 if full_signature.return_annotation is not inspect.Signature.empty:
    452                     # Now type-check the return value. We need to include the

~/projects/coordinax/src/coordinax/_src/vectors/base/base.py in ?(self, other)
    614         if type(other) is not type(self):
    615             return NotImplemented
    616 
    617         comp_tree = tree.map(jnp.equal, self, other)
--> 618         comp_leaves = jnp.array(tree.leaves(comp_tree))
    619         return jax.numpy.logical_and.reduce(comp_leaves)

~/projects/coordinax/.venv/lib/python3.10/site-packages/quax/_core.py in ?(self, *args, **kwargs)
    320                 dynamic,
    321                 is_leaf=_is_value,
    322             )
    323             fn, args, kwargs = eqx.combine(dynamic, static)
--> 324             out = fn(*args, **kwargs)
    325             out = jtu.tree_map(ft.partial(_unwrap_tracer, trace), out)
    326             return out

~/projects/coordinax/.venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py in ?(object, dtype, copy, order, ndmin, device)
   5012     assert object.aval is not None
   5013     out = _array_copy(object) if copy else object
   5014   elif isinstance(object, (list, tuple)):
   5015     if object:
-> 5016       out = stack([asarray(elt, dtype=dtype) for elt in object])
   5017     else:
   5018       out = np.array([], dtype=dtype)
   5019   elif _supports_buffer_protocol(object):

~/projects/coordinax/.venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py in ?(arrays, axis, out, dtype)
   4103     axis = _canonicalize_axis(axis, len(shape0) + 1)
   4104     new_arrays = []
   4105     for a in arrays:
   4106       if shape(a) != shape0:
-> 4107         raise ValueError("All input arrays must have the same shape.")
   4108       new_arrays.append(expand_dims(a, axis))
   4109     return concatenate(new_arrays, axis=axis, dtype=dtype)

ValueError: All input arrays must have the same shape.
nstarman commented 3 days ago

Darn, OK. I need to mess around with the `full_shaped` utility. I would like to come up with a better solution that doesn't require reshaping or copying arrays.