nnaisense / evotorch

Advanced evolutionary computation library built directly on top of PyTorch, created at NNAISENSE.
https://evotorch.ai
Apache License 2.0

KeyError: 'total_interaction_count' with SNES #58

Closed maulberto3 closed 1 year ago

maulberto3 commented 1 year ago

When I create a searcher with searcher = SNES(prob, stdev_init=0.1, num_interactions=100) and then run it, I get this error.

engintoklu commented 1 year ago

Hi @maulberto3!

Thank you very much for your feedback!

The keyword argument num_interactions was introduced for use with reinforcement learning problems, which provide an item with the key total_interaction_count in their status dictionaries. Could it be that your problem object (prob in your example code) is not a reinforcement learning problem, so its status does not have the key total_interaction_count?
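To illustrate why the missing key produces a KeyError, here is a hypothetical, simplified sketch in plain Python. It is not EvoTorch's actual implementation. It assumes that num_interactions makes the searcher keep sampling and evaluating within one generation until the number of interactions counted since the start of that generation reaches the target, capped by a maximum population size. The classes PlainProblem and RLLikeProblem, the function sample_generation, and its popsize_max argument are made up for illustration. The only detail taken from this thread is that the count is read from the status key total_interaction_count.

```python
# Hypothetical sketch, NOT EvoTorch's source: it only illustrates why a
# searcher relying on num_interactions needs the problem's status
# dictionary to contain "total_interaction_count".


class PlainProblem:
    """Stand-in for a non-RL problem: its status has no interaction count."""

    def __init__(self):
        self.status = {}

    def evaluate(self, batch_size):
        pass  # fitness evaluation would happen here


class RLLikeProblem(PlainProblem):
    """Stand-in for an RL problem that counts environment interactions."""

    def __init__(self, steps_per_solution):
        super().__init__()
        self.steps_per_solution = steps_per_solution
        self.status["total_interaction_count"] = 0

    def evaluate(self, batch_size):
        # Each evaluated solution consumes a fixed number of env steps here.
        self.status["total_interaction_count"] += batch_size * self.steps_per_solution


def sample_generation(problem, popsize, num_interactions, popsize_max):
    """Evaluate batches of `popsize` solutions until `num_interactions`
    interactions happened in this generation, or `popsize_max` solutions
    were evaluated. Returns the number of solutions evaluated."""
    start = problem.status["total_interaction_count"]  # KeyError for PlainProblem
    evaluated = 0
    while True:
        problem.evaluate(popsize)
        evaluated += popsize
        done = problem.status["total_interaction_count"] - start
        if done >= num_interactions or evaluated >= popsize_max:
            return evaluated
```

With RLLikeProblem(steps_per_solution=10), popsize=4 and num_interactions=100, each batch adds 40 interactions, so the sketch stops after three batches (12 solutions). With PlainProblem, the very first status lookup fails with KeyError: 'total_interaction_count', which matches the reported symptom.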

I looked at the docstrings of SNES and realized that the documentation for the keyword argument num_interactions says that it works with a problem class named GymProblem. I have to fix the documentation because the correct class name is GymNE (full name of the class: evotorch.neuroevolution.GymNE). Perhaps this mismatch in the documentation caused this confusion? Sorry about that. I will fix the docstrings.

Is it indeed the case that your problem object is not a reinforcement learning task?

NaturalGradient commented 1 year ago

Assuming that the issue is as described by @engintoklu, I'm closing this issue for now. @maulberto3, please reach out and reopen the issue if you believe there is a bug in the usage of num_interactions.

afzalmushtaque commented 10 months ago

This error occurs if I use the VecGymNE class, because it is missing the hook and extra stats that the GymNE class has:

```python
# evotorch/neuroevolution/gymne.py (excerpt)
class GymNE(NEProblem):
    def __init__(...):
        ...
        self.after_eval_hook.append(self._extra_status)

    def _extra_status(self, batch: SolutionBatch):
        return dict(total_interaction_count=self.interaction_count, total_episode_count=self.episode_count)
```
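Below is a hypothetical, simplified model of the pattern in the GymNE excerpt, assuming (as the KeyError suggests) that the dictionaries returned by after_eval_hook callbacks end up in the problem's status. The classes MiniProblem, MiniGymNE and MiniVecGymNE are invented for illustration and are not EvoTorch code; they only show why a problem that counts interactions but registers no hook still lacks the total_interaction_count key.

```python
# Hypothetical, simplified model of the GymNE pattern: callbacks stored in
# `after_eval_hook` return dicts that are merged into `status`.
# Class names are invented; this is not EvoTorch's code.


class MiniProblem:
    def __init__(self):
        self.after_eval_hook = []  # callables taking a batch, returning a dict
        self.status = {}
        self.interaction_count = 0
        self.episode_count = 0

    def evaluate(self, batch):
        # Pretend each solution runs one episode of 25 environment steps.
        self.episode_count += len(batch)
        self.interaction_count += 25 * len(batch)
        # Merge every hook's extra stats into the status dictionary.
        for hook in self.after_eval_hook:
            self.status.update(hook(batch))


class MiniGymNE(MiniProblem):
    """Registers the extra-status hook, like GymNE in the excerpt."""

    def __init__(self):
        super().__init__()
        self.after_eval_hook.append(self._extra_status)

    def _extra_status(self, batch):
        return dict(
            total_interaction_count=self.interaction_count,
            total_episode_count=self.episode_count,
        )


class MiniVecGymNE(MiniProblem):
    """Counts interactions too, but registers no hook (the reported gap)."""
```

After evaluating a batch of three solutions, MiniGymNE's status holds both counters, while MiniVecGymNE has an interaction_count of 75 but an empty status, so any lookup of total_interaction_count fails.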

A workaround is to manually register the same hook that GymNE uses on the VecGymNE object after initializing it:

```python
problem = VecGymNE(...)
problem.after_eval_hook.append(lambda batch: dict(total_interaction_count=problem.interaction_count, total_episode_count=problem.episode_count))
```
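If the workaround is applied in more than one place, a small helper can keep the hook from being registered twice. This is a sketch under assumptions: add_interaction_status_hook and the _interaction_hook_added marker attribute are hypothetical names, and the helper only relies on what the workaround above already uses (an after_eval_hook that supports append, plus interaction_count and episode_count attributes). It also assumes the problem object accepts new instance attributes. A SimpleNamespace stands in for a real VecGymNE so the example runs without a Gym environment.

```python
from types import SimpleNamespace


def add_interaction_status_hook(problem):
    """Register a hook reporting interaction/episode counts, at most once.

    Expects `problem.after_eval_hook` (supporting `append`) and the
    attributes `interaction_count` and `episode_count`, as used by the
    VecGymNE workaround.
    """

    def _extra_status(batch):
        # Read the counters at call time so the reported values are current.
        return dict(
            total_interaction_count=problem.interaction_count,
            total_episode_count=problem.episode_count,
        )

    # Hypothetical marker attribute that prevents duplicate registration.
    if not getattr(problem, "_interaction_hook_added", False):
        problem.after_eval_hook.append(_extra_status)
        problem._interaction_hook_added = True
    return problem


# Usage with a stand-in object (a real VecGymNE instance would be passed instead):
stub = SimpleNamespace(after_eval_hook=[], interaction_count=0, episode_count=0)
add_interaction_status_hook(stub)
add_interaction_status_hook(stub)  # second call is a no-op
```

A named inner function behaves the same as the lambda in the workaround; the only addition is the guard against stacking duplicate hooks.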