...
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/var/folders/67/2zgxpmyd2z183j4k6r33nf740000gr/T/ipykernel_73958/1255407117.py in ?()
2 rho=Quantity([1.0, 2.0], "kpc"),
3 phi=Quantity([0.0, 0.2], "rad"),
4 z=Quantity(0.0, "kpc")
5 )
----> 6 q == q
[... skipping hidden 1 frame]
~/projects/coordinax/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py in ?(args, kwargs, bound, memos)
445 else:
446 raise TypeCheckError(msg) from e
447
448 # Actually call the function.
--> 449 out = fn(*args, **kwargs)
450
451 if full_signature.return_annotation is not inspect.Signature.empty:
452 # Now type-check the return value. We need to include the
~/projects/coordinax/src/coordinax/_src/vectors/base/base_pos.py in ?(self, other)
132 if not isinstance(other, AbstractPos):
133 return NotImplemented
134
135 rhs = other.represent_as(type(self))
--> 136 return super().__eq__(rhs)
[... skipping hidden 1 frame]
~/projects/coordinax/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py in ?(args, kwargs, bound, memos)
445 else:
446 raise TypeCheckError(msg) from e
447
448 # Actually call the function.
--> 449 out = fn(*args, **kwargs)
450
451 if full_signature.return_annotation is not inspect.Signature.empty:
452 # Now type-check the return value. We need to include the
~/projects/coordinax/src/coordinax/_src/vectors/base/base.py in ?(self, other)
614 if type(other) is not type(self):
615 return NotImplemented
616
617 comp_tree = tree.map(jnp.equal, self, other)
--> 618 comp_leaves = jnp.array(tree.leaves(comp_tree))
619 return jax.numpy.logical_and.reduce(comp_leaves)
~/projects/coordinax/.venv/lib/python3.10/site-packages/quax/_core.py in ?(self, *args, **kwargs)
320 dynamic,
321 is_leaf=_is_value,
322 )
323 fn, args, kwargs = eqx.combine(dynamic, static)
--> 324 out = fn(*args, **kwargs)
325 out = jtu.tree_map(ft.partial(_unwrap_tracer, trace), out)
326 return out
~/projects/coordinax/.venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py in ?(object, dtype, copy, order, ndmin, device)
5012 assert object.aval is not None
5013 out = _array_copy(object) if copy else object
5014 elif isinstance(object, (list, tuple)):
5015 if object:
-> 5016 out = stack([asarray(elt, dtype=dtype) for elt in object])
5017 else:
5018 out = np.array([], dtype=dtype)
5019 elif _supports_buffer_protocol(object):
~/projects/coordinax/.venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py in ?(arrays, axis, out, dtype)
4103 axis = _canonicalize_axis(axis, len(shape0) + 1)
4104 new_arrays = []
4105 for a in arrays:
4106 if shape(a) != shape0:
-> 4107 raise ValueError("All input arrays must have the same shape.")
4108 new_arrays.append(expand_dims(a, axis))
4109 return concatenate(new_arrays, axis=axis, dtype=dtype)
ValueError: All input arrays must have the same shape.
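Darn. The failure happens because `z` is a scalar while `rho` and `phi` have shape `(2,)`. As a result, the per-component `jnp.equal` results have different shapes, and `jnp.array` cannot stack them. OK, I need to rework this using the `full_shaped` utility.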
I'd like to find a better solution that doesn't require reshaping or copying arrays.
Example: