Closed by jakubMitura14 2 years ago
Hi,
Thanks for using EvoJAX and for the question. EvoJAX has no knowledge of whether the policy network is differentiable; it simply optimizes the parameters it is given. I'm not entirely sure about your use case, but the following approach may help.
# Ask the solver only for the parameters that belong to the non-differentiable operations.
solver = NeuroEvolutionAlgo(param_size=non_diff_param_size)

# Define the policy to contain both differentiable and non-differentiable operations.
class AwesomePolicy:
    def __init__(self):
        # Policy network definition and init.
        # Save the parameters that are differentiable.
        self.diff_params = xxxxx

    def reset(self):
        # Save the differentiable params in the policy state.
        return PolicyState(params=self.diff_params)

    def get_action(self, params, p_state):
        # "params" comes from the solver; these belong to the non-differentiable ops.
        load_diff_params_to_model(p_state.params)
        load_non_diff_params(params)
        # Model inference
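The same split can be sketched in plain JAX without relying on EvoJAX's actual API. Everything in this sketch is an assumption made for illustration: the toy model, the names `forward`, `fitness`, and `evolve`, and the simple elite-averaging search that stands in for a real solver. The differentiable weights are held fixed (in practice they would come from gradient training), and only the threshold of a non-differentiable step function is evolved.

```python
import jax
import jax.numpy as jnp

def forward(diff_params, non_diff_params, x):
    # Differentiable part: a linear layer.
    hidden = x @ diff_params["w"] + diff_params["b"]
    # Non-differentiable part: a hard threshold whose cut-off is evolved.
    return (hidden > non_diff_params[0]).astype(jnp.float32)

def fitness(non_diff_params, diff_params, x, y):
    # Higher is better: negative mean squared error.
    pred = forward(diff_params, non_diff_params, x)
    return -jnp.mean((pred - y) ** 2)

def evolve(diff_params, x, y, key, pop_size=32, sigma=0.5, steps=20):
    # Hypothetical stand-in for a neuroevolution solver: sample around a mean,
    # keep the top quarter, and move the mean to the elite average.
    mean = jnp.zeros((1,))
    best, best_score = mean, fitness(mean, diff_params, x, y)
    batched_fitness = jax.vmap(fitness, in_axes=(0, None, None, None))
    n_elite = pop_size // 4
    for _ in range(steps):
        key, sub = jax.random.split(key)
        population = mean + sigma * jax.random.normal(sub, (pop_size, 1))
        scores = batched_fitness(population, diff_params, x, y)
        order = jnp.argsort(scores)  # ascending, so the best candidates are last
        mean = population[order[-n_elite:]].mean(axis=0)
        top = order[-1]
        if scores[top] > best_score:
            best, best_score = population[top], scores[top]
    return best, best_score

# Fixed differentiable parameters and toy data with a true threshold of 0.7.
diff_params = {"w": jnp.ones((1, 1)), "b": jnp.zeros((1,))}
x = jnp.linspace(-2.0, 2.0, 41).reshape(-1, 1)
y = (x > 0.7).astype(jnp.float32)

best, best_score = evolve(diff_params, x, y, jax.random.PRNGKey(0))
```

Only the one-element vector of non-differentiable parameters is exposed to the search, mirroring `param_size=non_diff_param_size` above; the differentiable parameters travel alongside it unchanged.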
Let me know if that works or if you have further questions.
Thank you! I will experiment with this. To avoid cluttering your issues, I'll close this for now and reopen it if I run into further problems. Once again, thanks for publishing this fantastic work!
I have a model in JAX (a convolutional neural network with some modifications) that is mostly differentiable, but some parts are not. Can I somehow mark which parts are non-differentiable so that gradient backpropagation is correct, or is this handled automatically?
Thanks!
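For background on the question, here is a small sketch of how JAX's own autodiff behaves around non-differentiable operations. It does not involve EvoJAX, and the function names are illustrative. Many piecewise-constant operations such as `jnp.round` need no marking and simply yield zero gradients, `jax.lax.stop_gradient` blocks gradient flow explicitly, and the straight-through trick in the last function is one common workaround whose suitability for this particular model is an assumption.

```python
import jax
import jax.numpy as jnp

def hard_round_loss(x):
    # jnp.round is piecewise constant, so no gradient flows back through it.
    return jnp.sum(jnp.round(x) ** 2)

def blocked_loss(x):
    # stop_gradient makes the wrapped value act as a constant during backprop,
    # so only the second factor contributes to the gradient.
    return jnp.sum(jax.lax.stop_gradient(x) * x)

def straight_through_loss(x):
    # Straight-through estimator: the forward pass uses round(x),
    # while the backward pass treats the rounding as the identity.
    x_rounded = x + jax.lax.stop_gradient(jnp.round(x) - x)
    return jnp.sum(x_rounded ** 2)

x = jnp.array([0.3, 1.6, -2.2])

g_round = jax.grad(hard_round_loss)(x)        # all zeros
g_blocked = jax.grad(blocked_loss)(x)         # equals x
g_ste = jax.grad(straight_through_loss)(x)    # equals 2 * round(x)
```

A zero gradient means that parameters upstream of such an operation receive no learning signal from gradient descent, which is the situation where evolving those parameters instead becomes attractive.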