google / evojax

Apache License 2.0

Can one specify parts of the model that are non-differentiable? #41

Closed: jakubMitura14 closed this issue 2 years ago

jakubMitura14 commented 2 years ago

I have a model in JAX (a convolutional neural network with some modifications) where most of it is fully differentiable, but some parts are not. Can I somehow mark which parts are non-differentiable so that gradients backpropagate correctly, or is this handled automatically?

Thanks!
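For context: in plain JAX (independently of EvoJAX), a common way to exclude part of a computation from backpropagation is `jax.lax.stop_gradient`, which passes values through unchanged in the forward pass but contributes no gradient. A minimal sketch, where the function `model` and its two scalar weights are hypothetical:

```python
import jax
import jax.numpy as jnp

def model(w_diff, w_nondiff, x):
    # Differentiable part: an elementwise scaling.
    h = w_diff * x
    # Branch treated as a constant by autodiff: no gradient flows through it,
    # neither to w_nondiff nor back into h.
    g = jax.lax.stop_gradient(w_nondiff * h)
    return jnp.sum(h + g)

x = jnp.array([1.0, 2.0])
value = model(3.0, 5.0, x)
grad_diff, grad_nondiff = jax.grad(model, argnums=(0, 1))(3.0, 5.0, x)
print(value, grad_diff, grad_nondiff)  # 54.0 3.0 0.0
```

The forward value still includes the stopped branch (18 + 36 = 54), but only `w_diff` receives a gradient.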

lerrytang commented 2 years ago

Hi,

Thanks for using EvoJAX and for the question. EvoJAX has no knowledge of whether the policy network is differentiable; it simply optimizes whatever parameters it is given. I'm not entirely sure about your use case, but the following approach may help.

# Only ask the solver for the parameters that belong to the non-differentiable operations.
# NeuroEvolutionAlgo stands in for an EvoJAX solver (e.g. PGPE).
solver = NeuroEvolutionAlgo(param_size=non_diff_param_size)

# Define the policy to contain both differentiable and non-differentiable operations.
class AwesomePolicy:

    def __init__(self):
        # Policy network definition and initialization.
        # Save the parameters that are differentiable (xxxxx is a placeholder).
        self.diff_params = xxxxx

    def reset(self):
        # Store the differentiable params in the policy state
        # (this assumes a PolicyState type that has a `params` field).
        return PolicyState(params=self.diff_params)

    def get_action(self, params, p_state):
        # The "params" argument comes from the solver; it holds the non-differentiable ops' parameters.
        # load_diff_params_to_model and load_non_diff_params are placeholder helpers.
        load_diff_params_to_model(p_state.params)
        load_non_diff_params(params)
        # Model inference goes here.
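To make the split concrete, here is a self-contained sketch in plain JAX rather than EvoJAX's actual `PolicyNetwork` interface; `HybridPolicyState`, `get_action`, and the hard-threshold gate are all hypothetical. The differentiable weights live in the policy state, while the solver only searches over the per-unit thresholds, so its `param_size` would be `hidden_dim`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class HybridPolicyState(NamedTuple):
    # Differentiable weights: kept in the policy state, outside the solver's search space.
    diff_params: jnp.ndarray  # shape (obs_dim, hidden_dim)

def get_action(non_diff_params, p_state, obs):
    # Differentiable part: a linear layer with the stored weights.
    h = obs @ p_state.diff_params
    # Non-differentiable part: a hard gate whose thresholds come from the solver.
    gated = jnp.where(h > non_diff_params, h, 0.0)
    return jnp.argmax(gated, axis=-1)

obs_dim, hidden_dim, pop_size, batch = 3, 4, 5, 2
p_state = HybridPolicyState(
    diff_params=jax.random.normal(jax.random.PRNGKey(0), (obs_dim, hidden_dim)))
obs = jax.random.normal(jax.random.PRNGKey(1), (batch, obs_dim))

# One row of thresholds per population member (hypothetical stand-in for a solver's output).
population = jax.random.normal(jax.random.PRNGKey(2), (pop_size, hidden_dim))
# Map over population members only; the state and observations are shared.
actions = jax.vmap(get_action, in_axes=(0, None, None))(population, p_state, obs)
print(actions.shape)  # (5, 2)
```

Since the solver never differentiates the policy, the hard gate needs no special marking.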

Let me know if that works, or if you have further questions.
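Differentiability doesn't matter to a black-box solver because evolution strategies only consume fitness values. A minimal (1+λ) evolution strategy in plain JAX illustrates this on an objective whose gradient JAX reports as zero; this is a textbook ES for illustration, not any of EvoJAX's built-in solvers, and `fitness` and `one_plus_lambda_es` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def fitness(x):
    # Piecewise-constant objective: jnp.round has a zero derivative in JAX,
    # so gradient ascent would make no progress here.
    return -jnp.sum((jnp.round(x) - 3.0) ** 2)

def one_plus_lambda_es(key, x0, sigma=1.0, pop_size=16, num_iters=200):
    """(1+lambda) evolution strategy: uses fitness values only, never gradients."""
    best_x, best_f = x0, fitness(x0)
    for _ in range(num_iters):
        key, subkey = jax.random.split(key)
        # Sample a population of perturbed candidates around the current best.
        noise = jax.random.normal(subkey, (pop_size,) + x0.shape)
        candidates = best_x + sigma * noise
        scores = jax.vmap(fitness)(candidates)
        i = jnp.argmax(scores)
        # Elitist update: accept the best candidate only if it improves.
        if scores[i] > best_f:
            best_x, best_f = candidates[i], scores[i]
    return best_x, best_f

x0 = jnp.zeros(2)
print(jax.grad(fitness)(x0))  # all zeros: gradient ascent would stall
best_x, best_f = one_plus_lambda_es(jax.random.PRNGKey(0), x0)
print(best_x, best_f)
```

The elitist acceptance rule means the best fitness never decreases, even though the search never looks at a gradient.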

jakubMitura14 commented 2 years ago

Thank you! I will experiment with this. To avoid cluttering your issues, I'll close this for now and reopen it if I run into further problems. Once again, thanks for publishing this fantastic work!