import numpy as np
import jax
import mudata
# Give each modality its own PRNG stream. Passing the same key to both
# jax.random.normal calls would make the two matrices statistically dependent
# (JAX keys must not be reused).
key_x, key_y = jax.random.split(jax.random.PRNGKey(1))
x = jax.random.normal(key=key_x, shape=(100, 20))
y = jax.random.normal(key=key_y, shape=(110, 30))

# Wrap each matrix in an AnnData. AnnData defaults var_names to "0", "1", ...,
# so without renaming, "0".."19" would appear in both modalities. MuData expects
# var_names to be unique across modalities. Constructing it from the default,
# overlapping names failed with:
#   NotImplementedError: var_names seem to have been renamed and filtered at
#   the same time.
# Prefixing each feature with its modality key keeps the names unique.
# obs_names are deliberately left shared: the same observations may appear in
# several modalities.
modalities = {}
for mod_name, values in {"x": x, "y": y}.items():
    adata = mudata.AnnData(np.asarray(values))
    adata.var_names = [f"{mod_name}:{i}" for i in range(adata.n_vars)]
    modalities[mod_name] = adata

m = mudata.MuData(modalities)