EMI-Group / tensorneat

GPU-accelerated NeuroEvolution of Augmenting Topologies (NEAT)
BSD 3-Clause "New" or "Revised" License

Brax HyperNEAT FullSubstrate with Hidden Nodes implementation #3

Closed by holyPancakes 6 months ago

holyPancakes commented 6 months ago

I'm trying to implement HyperNEAT for Brax, using the reacher environment as a trial experiment (FullSubstrate with 12 inputs including 1 bias, 12 hidden nodes, and 2 outputs). I'm hitting an error in the input-count check in the HyperNEAT class initialization:

assert substrate.query_coors.shape[1] == neat.num_inputs, \
    "Substrate input size should be equal to NEAT input size" 

Did I understand the implementation correctly that substrate.query_coors is the list of all connections from input to hidden, hidden to hidden, and hidden to output all combined?

substrate.query_coors looks like this:

[[-1.    -1.    -1.    -0.5  ]
 ...
 [ 1.     0.5    0.333  1.   ]]

substrate.query_coors.shape[1] is always 4, because each row represents one connection as [x1, y1, x2, y2]. That holds even though neat.num_inputs is 12 (the reacher environment has 11 inputs + 1 bias).

I checked the implementation in xor3d_hyperneat.py, and there substrate.query_coors.shape[1] is also 4, which only matches neat.num_inputs by coincidence because XOR has 3 inputs + 1 bias.
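To make the shape argument concrete, here is a minimal sketch of how a full substrate's query coordinates could be assembled, assuming the layout described in the question (input to hidden, hidden to hidden, hidden to output). The helpers `pair_coords` and `layer` and the node placements are hypothetical and are not tensorneat's API. The only point is that each row describes one connection as (x1, y1, x2, y2), so the second dimension stays 4 however many nodes a layer has.

```python
import jax.numpy as jnp

def pair_coords(src, dst):
    """Return every (source, target) pair as rows of [x1, y1, x2, y2]."""
    # src has shape (n_src, 2), dst has shape (n_dst, 2)
    s = jnp.repeat(src, dst.shape[0], axis=0)   # each source repeated once per target
    d = jnp.tile(dst, (src.shape[0], 1))        # targets cycled for every source
    return jnp.concatenate([s, d], axis=1)      # shape (n_src * n_dst, 4)

def layer(n, y):
    """Place n nodes evenly on [-1, 1] at height y (illustrative layout only)."""
    xs = jnp.linspace(-1.0, 1.0, n)
    return jnp.stack([xs, jnp.full((n,), y)], axis=1)

inputs = layer(12, -1.0)   # 11 observations + 1 bias
hidden = layer(12, 0.0)
outputs = layer(2, 1.0)

query_coords = jnp.concatenate([
    pair_coords(inputs, hidden),    # 12 * 12 rows
    pair_coords(hidden, hidden),    # 12 * 12 rows
    pair_coords(hidden, outputs),   # 12 * 2 rows
])
print(query_coords.shape)
```

Changing the node counts changes only the number of rows; the column count is fixed by the dimensionality of the substrate coordinates.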

If I bypass the assert by commenting it out, I get this error:

Traceback (most recent call last):
  File "C:\Users\Admin\Desktop\tensorneat\reacherHyperNEAT.py", line 51, in <module>
    state, best = pipeline.auto_run(state)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\Desktop\tensorneat\pipeline.py", line 83, in auto_run
    compiled_step = jax.jit(self.step).lower(ini_state).compile()
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\Desktop\tensorneat\pipeline.py", line 63, in step
    pop_transformed = jax.vmap(self.algorithm.transform)(pop)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\Desktop\tensorneat\algorithm\hyperneat\hyperneat.py", line 58, in transform
    query_res = jax.vmap(self.neat.forward, in_axes=(0, None))(self.substrate.query_coors, transformed)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\Desktop\tensorneat\algorithm\neat\neat.py", line 57, in forward
    return self.genome.forward(inputs, transformed)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\Desktop\tensorneat\algorithm\neat\genome\default.py", line 54, in forward
    ini_vals = ini_vals.at[self.input_idx].set(inputs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\AppData\Local\Programs\Python\Python312\Lib\site-packages\jax\_src\numpy\array_methods.py", line 490, in set
    return scatter._scatter_update(self.array, self.index, values, lax.scatter,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\AppData\Local\Programs\Python\Python312\Lib\site-packages\jax\_src\ops\scatter.py", line 80, in _scatter_update
    return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\AppData\Local\Programs\Python\Python312\Lib\site-packages\jax\_src\ops\scatter.py", line 118, in _scatter_impl
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\AppData\Local\Programs\Python\Python312\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 1229, in broadcast_to
    return util._broadcast_to(array, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\AppData\Local\Programs\Python\Python312\Lib\site-packages\jax\_src\numpy\util.py", line 430, in _broadcast_to
    raise ValueError(msg.format(arr_shape, shape))
ValueError: Incompatible shapes for broadcasting: (4,) and requested shape (12,)
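The final ValueError can be reproduced in isolation. The sketch below is not the library's code: `input_idx` and `ini_vals` are stand-ins whose sizes are assumptions chosen to mirror the `ini_vals.at[self.input_idx].set(inputs)` line in the traceback. It shows that writing a 4-element coordinate row into 12 input slots cannot broadcast, while 4 input slots accept it.

```python
import jax.numpy as jnp

num_inputs = 12                                # NEAT configured with the environment's input count
input_idx = jnp.arange(num_inputs)             # stand-in for the genome's input node indices
ini_vals = jnp.full((20,), jnp.nan)            # stand-in for the node value buffer
coords = jnp.array([-1.0, -1.0, -1.0, -0.5])   # one query row: (x1, y1, x2, y2)

try:
    ini_vals.at[input_idx].set(coords)         # 4 values into 12 slots
    error = None
except ValueError as exc:
    error = str(exc)
print(error)

# With 4 input slots the same update succeeds.
ok = ini_vals.at[jnp.arange(4)].set(coords)
print(ok[:4])
```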

Is anyone else having the same issue?

holyPancakes commented 6 months ago

OHHH I get it now. The NEAT inputs should always be 4 since it's supposed to be the inputs for the CPPN.
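For anyone landing here with the same confusion, a minimal sketch of that idea follows. The `cppn` function is a toy stand-in with made-up weights, not tensorneat's genome; it only mirrors the `jax.vmap(..., in_axes=(0, None))` call from the traceback. The CPPN takes the 4 coordinates of one connection and returns that connection's weight, so the network evolved by NEAT has 4 inputs, and the environment's 12 inputs belong to the substrate network built from those weights.

```python
import jax
import jax.numpy as jnp

def cppn(coords, params):
    """Toy stand-in for the evolved CPPN: 4 coordinate inputs -> 1 connection weight."""
    w1, w2 = params
    hidden = jnp.tanh(coords @ w1)   # (4,) @ (4, 8) -> (8,)
    return jnp.tanh(hidden @ w2)     # (8,) @ (8,) -> scalar

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
params = (jax.random.normal(key1, (4, 8)), jax.random.normal(key2, (8,)))

# Two rows taken from the query coordinates printed in the question.
query_coords = jnp.array([
    [-1.0, -1.0, -1.0, -0.5],
    [1.0, 0.5, 0.333, 1.0],
])

# Map over the rows of query_coords; share the same CPPN parameters for all rows.
weights = jax.vmap(cppn, in_axes=(0, None))(query_coords, params)
print(weights.shape)   # one weight per substrate connection
```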

WLS2002 commented 5 months ago

Thank you very much for trying out our library! I'm sorry for not getting back to you promptly. I'll make sure to check the project's issues more frequently in the future, so please feel free to continue raising any issues you encounter.