liziniu / RL-PPO-Keras

Proximal Policy Optimization(PPO) with Keras Implementation

Loading models doesn't work #1

Open ArnoDekkersLos opened 5 years ago

ArnoDekkersLos commented 5 years ago

I hope this project is not abandoned and that you are willing to patch this. Loading a saved model with the Agent.load_model method makes Keras raise the exception: 'ValueError: Unknown loss function:loss'

The usual fix is to pass the loss through custom_objects, i.e. to change the line

    self.actor_network = load_model(self.dic_path["PATH_TO_MODEL"], "%s_actor_network.h5")

(which I had already changed to)

    self.actor_network = load_model("ppo/actor_network.h5")

to

    self.actor_network = load_model("ppo/actor_network.h5", custom_objects={'loss': self.loss})

However, the loss function is an inner function, so it cannot be referenced as self.loss. Passing proximal_policy_optimization_loss (which generates the loss function) instead raises the exception: "AttributeError: 'function' object has no attribute 'get_shape'"
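The difference between the factory and the inner function can be reproduced without Keras. The sketch below is only an illustration: the body of the loss is a hypothetical, scalar stand-in and not the repository's actual implementation, and the `clip` parameter is an assumption. It shows that the inner function is named `loss` (the name the first error message complains about), and that calling the factory with `(y_true, y_pred)` returns another function rather than a loss value, which is a plausible reason for the 'get_shape' error.

```python
# Hypothetical, simplified stand-in for the repo's loss factory (scalars instead of tensors).
def proximal_policy_optimization_loss(advantage, old_prediction, clip=0.2):
    def loss(y_true, y_pred):
        # Probability ratio between the new and the old policy for the taken action.
        ratio = (y_true * y_pred) / (y_true * old_prediction + 1e-10)
        # Keep the ratio inside [1 - clip, 1 + clip].
        clipped = max(min(ratio, 1.0 + clip), 1.0 - clip)
        # Negated clipped surrogate objective, so that minimizing it improves the policy.
        return -min(ratio * advantage, clipped * advantage)
    return loss


# Intended use: build the closure first, then call it as a loss.
real_loss = proximal_policy_optimization_loss(advantage=1.0, old_prediction=0.5)
value = real_loss(1.0, 0.5)

# Registering the factory itself as the loss means it gets called with (y_true, y_pred).
# The result is the inner function, not a number or tensor.
wrong = proximal_policy_optimization_loss(1.0, 0.5)

print(real_loss.__name__)           # name of the inner function
print(type(value).__name__)         # a plain number here, a tensor in Keras
print(callable(wrong))              # the factory returned a function
print(hasattr(wrong, "get_shape"))  # functions have no get_shape attribute
```

If this reading is right, `custom_objects` would need the already-built closure (the value returned by the factory), which in turn needs the advantage and old-prediction inputs to exist before the model is loaded.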

I've been trying to fix this by:
- loading the weights rather than the model
- creating a standalone loss function that reads its parameters from self
- creating a standalone loss function and wrapping it in a lambda (https://stackoverflow.com/a/54177997/8579225)

but I can't seem to fix things. Hope you are willing to help me out with this.

navallo commented 5 years ago

This is a late reply, but here is a solution:
1. Load the weights rather than the whole model.
2. Use 'build_network_from_copy' rather than 'deepcopy'.

That is:

    # Both methods need `import os` at module level.
    def save_model_weights(self, file_name):
        # Weights-only files store no loss or optimizer configuration.
        self.actor_network.save_weights(os.path.join(self.dic_path["PATH_TO_MODEL"], "%s_actor_weights.h5" % file_name))
        self.critic_network.save_weights(os.path.join(self.dic_path["PATH_TO_MODEL"], "%s_critic_weights.h5" % file_name))

    def load_model_weights(self, file_name):
        # The networks must already be built with the same architecture before loading.
        self.actor_network.load_weights(os.path.join(self.dic_path["PATH_TO_MODEL"], "%s_actor_weights.h5" % file_name))
        self.critic_network.load_weights(os.path.join(self.dic_path["PATH_TO_MODEL"], "%s_critic_weights.h5" % file_name))
        # Rebuild the old-policy network from the freshly loaded actor instead of using deepcopy.
        self.actor_old_network = self.build_network_from_copy(self.actor_network)