RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0

[Feature request] Integration with equinox #52

Open AntonyJia159 opened 1 year ago

AntonyJia159 commented 1 year ago

Equinox is a minimalistic JAX neural network library. Currently, evosax's ParameterReshaper doesn't seem to support it: the returned modules carry a population (batch) dimension in their weights. For example,

rng = jax.random.PRNGKey(0)
network = eqx.nn.GRUCell(16, 20, key=rng)
param_reshaper = ParameterReshaper(network)

gets:

GRUCell(
    weight_ih=f32[1,60,16],
    weight_hh=f32[1,60,20],
    bias=f32[1,60],
    bias_n=f32[1,20],
    input_size=16,
    hidden_size=20,
    use_bias=True
)

For a population of size one, the correct form should be a cell like this:

GRUCell(
    weight_ih=f32[60,16],
    weight_hh=f32[60,20],
    bias=f32[60],
    bias_n=f32[20],
    input_size=16,
    hidden_size=20,
    use_bias=True
)

It's possible that there was just something I didn't figure out about the reshaper. If that's the case, please kindly inform me. evosax is an awesome project, and more integration with the whole ecosystem would make it even better : )

RobertTLange commented 1 year ago

Thank you for raising this @AntonyJia159! I would love to support Equinox -- although I have to say that I have never worked with it, and as you correctly pointed out, evosax so far only supports neural network libraries that expose weight pytrees more explicitly. ParameterReshaper requires a pytree as input in order to extract the shapes needed to reshape the flat parameter vectors used in ES. So far I couldn't figure out how to elegantly extract these from Equinox modules -- but I am sure that this is possible. Furthermore, we might need a smooth way to plug the proposed ES candidate weights back in for forward passes. Please let me know if this makes sense and if you have a proposal! Cheers, Rob
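For readers unfamiliar with the reshaping step being discussed, the core idea can be sketched framework-agnostically: record each leaf's shape once, flatten all leaves into a single vector for the ES to mutate, then reverse the mapping to plug candidate vectors back in. This is only a conceptual sketch with hypothetical helper names and plain Python lists standing in for JAX arrays, not evosax's actual implementation:

```python
def extract_shapes(params):
    # Record the size of every leaf in a (flat) dict-of-lists pytree.
    return {name: len(leaf) for name, leaf in params.items()}

def flatten(params):
    # Concatenate all leaves (in a fixed key order) into one flat vector.
    flat = []
    for name in sorted(params):
        flat.extend(params[name])
    return flat

def unflatten(flat, shapes):
    # Slice a flat candidate vector back into the original structure.
    params, offset = {}, 0
    for name in sorted(shapes):
        n = shapes[name]
        params[name] = flat[offset:offset + n]
        offset += n
    return params

weights = {"weight": [0.1, 0.2, 0.3], "bias": [0.0]}
shapes = extract_shapes(weights)
flat = flatten(weights)               # the vector an ES strategy perturbs
restored = unflatten(flat, shapes)    # plugged back in for a forward pass
assert restored == weights
```

The round trip `unflatten(flatten(p), extract_shapes(p)) == p` is the invariant that ParameterReshaper has to provide for any supported module type; the open question in this issue is how to walk an Equinox module's leaves to get it.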

hctomkins commented 8 months ago

FYI @RobertTLange @AntonyJia159 - Equinox works OK out of the box! You just need to vmap the evaluation of the networks too:

import jax
import jax.numpy as jnp
from evosax import CMA_ES, ParameterReshaper
import equinox as eqx

def fitness(net, input, target):
    # Mean squared error of the network's prediction against the target.
    output = net(input)
    return jnp.mean((output - target) ** 2)

if __name__ == "__main__":
    # Set up single equinox network
    rng = jax.random.PRNGKey(0)
    fixed_input = jax.random.uniform(key=rng, shape=(16,))
    fixed_output = 18.0

    network = eqx.nn.Linear(16, 1, key=rng)
    print("Example single random fitness ", fitness(network, fixed_input, fixed_output))

    # Set up for evosax
    reshaper = ParameterReshaper(network)
    fitness_many = jax.vmap(fitness, in_axes=(0, None, None)) # Adjust 'None's if working on batched rather than fixed data

    # Instantiate the search strategy
    strategy = CMA_ES(popsize=20, num_dims=reshaper.total_params, elite_ratio=0.5)
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)

    # Run ask-eval-tell loop
    for t in range(10):
        rng, rng_gen, rng_eval = jax.random.split(rng, 3)
        candidate_params, state = strategy.ask(rng_gen, state, es_params)
        candidate_networks = reshaper.reshape(candidate_params)
        fitnesses = fitness_many(candidate_networks, fixed_input, fixed_output)
        print(jnp.min(fitnesses))
        state = strategy.tell(candidate_params, fitnesses, state, es_params)

    # Get best overall population member & its fitness
    print("best fitness:", state.best_fitness)
    print("best params:", state.best_member)

Gives:

Example single fitness  311.59064
ParameterReshaper: 17 parameters detected for optimization.
173.77896
34.950092
19.661118
3.1265647
0.10655613
0.3538381
0.049168102
0.32669815
0.00019323104
0.019950658
best fitness: 0.00019323104
best params: [ 2.3055873   2.9635496   5.8396025   4.113425    0.47761774  2.2564042
  2.6945195  -1.506904   -0.23346913 -1.9476291   3.9967122  -1.9457947
 -3.7017791  -1.44961     0.69735444  1.0968039   5.4032116 ]

Hopefully this is helpful to you guys / others!
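One follow-up detail if you want the winner back as a single usable module rather than a flat vector: the reshaped pytrees carry a leading population axis (the batch dimension the original issue describes), so you have to strip it from every array leaf yourself. A minimal JAX-only sketch of that step, with a toy dict pytree standing in for an Equinox module:

```python
import jax
import jax.numpy as jnp

# Toy "population" pytree, shaped like a reshaped candidate batch for
# popsize=2: every array leaf has a leading population axis.
population = {
    "weight": jnp.arange(6.0).reshape(2, 3),  # 2 members, 3 weights each
    "bias": jnp.array([[0.5], [1.5]]),        # 2 members, 1 bias each
}

# Slice the population axis off every leaf to recover member 0.
member = jax.tree_util.tree_map(lambda leaf: leaf[0], population)

print(member["weight"].shape)  # (3,)
print(member["bias"].shape)    # (1,)
```

For Equinox modules this should work the same way, since they are pytrees whose static fields (like input_size) are not leaves; I haven't verified it against every module type, though.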