There is a bug in the undo_flatten method of the CMA optimizer: the weights are not put back in their original positions, because length_flat_layer contains per-layer lengths, not start indices, yet the slice below uses those lengths as start indices.
for i, layer_shape in enumerate(self.shape):
    flat_layer = flattened_weights[
        self.length_flat_layer[i]: self.length_flat_layer[i] + self.length_flat_layer[i + 1]
    ]
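To see the effect concretely, here is a small reproduction with made-up layer shapes. It assumes length_flat_layer is [0, n_0, n_1, ...] (a leading 0 followed by each layer's size), which is what the [i + 1] indexing suggests; buggy_slices is a hypothetical helper that copies the slicing above. Under that assumption the first two layers happen to come out right, but the third layer's slice starts at n_1 instead of n_0 + n_1, so it picks up weights belonging to earlier layers.

```python
import numpy as np

# Hypothetical layer shapes. length_flat_layer is ASSUMED to be
# [0, n_0, n_1, ...]: a leading 0 followed by each layer's size.
shapes = [(2, 2), (3,), (2,)]
length_flat_layer = [0] + [int(np.prod(s)) for s in shapes]  # [0, 4, 3, 2]

# Layer 0 holds 0..3, layer 1 holds 4..6, layer 2 holds 7..8
flattened_weights = np.arange(9.0)


def buggy_slices(flat, lengths, shapes):
    """Hypothetical helper reproducing the slicing from the snippet above."""
    out = []
    for i, _ in enumerate(shapes):
        # Uses a length as a start index: only correct for the first two layers
        out.append(flat[lengths[i]: lengths[i] + lengths[i + 1]])
    return out


for i, s in enumerate(buggy_slices(flattened_weights, length_flat_layer, shapes)):
    print(i, s)
```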
A quick fix could be to convert the lengths into start offsets with a running (cumulative) sum before slicing.
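A minimal sketch of that quick fix, written as a standalone function rather than the optimizer's method so it can be run on its own. It again assumes length_flat_layer is [0, n_0, n_1, ...]; the function signature is hypothetical. A cumulative sum turns the sizes into offsets [0, n_0, n_0 + n_1, ...], so layer i is simply the slice between offsets[i] and offsets[i + 1].

```python
import numpy as np


def undo_flatten(flattened_weights, shapes, length_flat_layer):
    """Hypothetical standalone version of undo_flatten.

    ASSUMES length_flat_layer == [0, n_0, n_1, ...], i.e. per-layer sizes
    preceded by a 0, as the original [i + 1] indexing suggests.
    """
    # Running sum turns sizes into offsets: [0, n_0, n_0 + n_1, ...]
    offsets = np.cumsum(length_flat_layer)
    weights = []
    for i, layer_shape in enumerate(shapes):
        # Layer i occupies [offsets[i], offsets[i + 1]) in the flat vector
        flat_layer = flattened_weights[offsets[i]: offsets[i + 1]]
        weights.append(np.reshape(flat_layer, layer_shape))
    return weights


# Usage with made-up shapes: layer 2 now gets weights 7..8 as expected
shapes = [(2, 2), (3,), (2,)]
lengths = [0] + [int(np.prod(s)) for s in shapes]
layers = undo_flatten(np.arange(9.0), shapes, lengths)
```

If length_flat_layer does not start with a 0 in the actual code, prepending one before the cumulative sum should give the same offsets.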