adaptive-intelligent-robotics / QDax

Accelerated Quality-Diversity
https://qdax.readthedocs.io/en/latest/
MIT License

MOME Container Implementation does not support non-array Genotypes #139

Open · Lookatator opened this issue 1 year ago

Lookatator commented 1 year ago

Hi,

It seems that the current implementation of the MOME container does not work when the Genotype is a pytree more general than an array. For example, the following line fails to run for neural-network Genotypes: https://github.com/adaptive-intelligent-robotics/QDax/blob/45c4c2a3eaaae7e78323ad1b19172e55ca68cd3d/qdax/core/containers/mome_repertoire.py#L314
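A minimal reproduction of the failure, assuming a neural-network genotype stored as a nested dict of weight arrays. The layer names and shapes are hypothetical, chosen only for illustration. Each leaf keeps a leading axis of size 1, as if the cell had been sliced out of a batched repertoire; calling `.squeeze` on the dict itself fails because only the leaves are arrays.

```python
import jax
import jax.numpy as jnp

# Hypothetical neural-network genotype for one repertoire cell:
# a nested dict (a pytree), not a single array. Every leaf keeps a
# leading axis of size 1, as if it had been sliced out of a batch.
cell_genotype = {
    "Dense_0": {"kernel": jnp.zeros((1, 4, 8)), "bias": jnp.zeros((1, 8))},
    "Dense_1": {"kernel": jnp.zeros((1, 8, 2)), "bias": jnp.zeros((1, 2))},
}

# The container is a dict, so it has no array methods such as squeeze.
try:
    cell_genotype.squeeze(axis=0)
    error_type = None
except AttributeError as e:
    error_type = type(e).__name__

print(error_type)  # AttributeError
```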

We have a working fix in the lab, which we will clean up before submitting it in a PR :)

limbryan commented 1 year ago

I think this has now been fixed with #143, right @felixchalumeau? There we changed the explicit type in the repertoire to account for non-array genotypes and arbitrary pytrees (https://github.com/adaptive-intelligent-robotics/QDax/blob/773e646b3bd1bf3fb1a1541686c7332193129e78/qdax/core/containers/mapelites_repertoire.py#L390).

Lookatator commented 1 day ago

Mmh, from what I see, the issue is still present on both the main and develop branches.

What I meant is that cell_genotype.squeeze(axis=0) fails to run if cell_genotype is not an array.

I suppose the line should instead be:

jax.tree_util.tree_map(lambda x: x.squeeze(axis=0), cell_genotype)
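A minimal sketch of that line in action, using a hypothetical neural-network genotype (a nested dict whose layer names and shapes are made up for illustration). The squeeze is applied to every leaf instead of to the container, and the same call still works when the genotype is a plain array.

```python
import jax
import jax.numpy as jnp

# Hypothetical single-cell genotype: every leaf has a leading axis of size 1.
cell_genotype = {
    "Dense_0": {"kernel": jnp.zeros((1, 4, 8)), "bias": jnp.zeros((1, 8))},
    "Dense_1": {"kernel": jnp.zeros((1, 8, 2)), "bias": jnp.zeros((1, 2))},
}

# Apply squeeze to each leaf instead of to the dict itself.
squeezed = jax.tree_util.tree_map(lambda x: x.squeeze(axis=0), cell_genotype)

# A bare array is a single-leaf pytree, so the same call also covers
# the plain-array genotypes that the current line already handles.
array_genotype = jnp.zeros((1, 5))
squeezed_array = jax.tree_util.tree_map(lambda x: x.squeeze(axis=0), array_genotype)

print(squeezed["Dense_0"]["kernel"].shape)  # (4, 8)
print(squeezed_array.shape)  # (5,)
```

Because `tree_map` applied to a bare array simply calls the function on that array, the change should not alter the behaviour for existing array genotypes.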

I guess the line below it, new_batch_of_genotypes=jnp.expand_dims(genotype, axis=0), should also be updated to handle pytrees.
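The leaf-wise `tree_map` pattern could carry over to `expand_dims`, assuming the intent of that line is still to add a leading batch axis of size 1 to the genotype. A minimal sketch on a hypothetical nested-dict genotype (names and shapes are illustrative only), with a round-trip check that squeezing the new axis away recovers the original leaves:

```python
import jax
import jax.numpy as jnp

# Hypothetical single genotype (no batch axis yet).
genotype = {
    "Dense_0": {"kernel": jnp.ones((4, 8)), "bias": jnp.ones((8,))},
    "Dense_1": {"kernel": jnp.ones((8, 2)), "bias": jnp.ones((2,))},
}

# Leaf-wise version of jnp.expand_dims(genotype, axis=0).
new_batch_of_genotypes = jax.tree_util.tree_map(
    lambda x: jnp.expand_dims(x, axis=0), genotype
)

# Round trip: removing the new axis from every leaf recovers the original.
recovered = jax.tree_util.tree_map(
    lambda x: x.squeeze(axis=0), new_batch_of_genotypes
)
same = all(
    jax.tree_util.tree_leaves(
        jax.tree_util.tree_map(
            lambda a, b: bool(jnp.array_equal(a, b)), genotype, recovered
        )
    )
)

print(new_batch_of_genotypes["Dense_0"]["kernel"].shape)  # (1, 4, 8)
print(same)  # True
```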

@hannah-jan, what do you think?