PacktPublishing / Deep-Reinforcement-Learning-Hands-On

Hands-on Deep Reinforcement Learning, published by Packt

Saving/Running Model from 08_dqn_rainbow.py #38

Closed: icompute386 closed this issue 5 years ago

icompute386 commented 5 years ago

Hi again, I was hoping you could help me modify the code for 08_dqn_rainbow.py to save the model, so that it can be replayed through dqn_play.py.

Here's what I have so far; it fails on: state, reward, done, _ = env.step(action)

MODIFIED CODE (from dqn_play.py, Chapter 6)

import torch.nn as nn

Vmax = 10
Vmin = -10
N_ATOMS = 51
DELTA_Z = (Vmax - Vmin) / (N_ATOMS - 1)

class RainbowDQN(nn.Module):
    ...  # full class body copied unchanged from Chapter 8's 08_dqn_rainbow.py

if __name__ == "__main__":
    # env is created earlier in dqn_play.py's __main__ block, as in the Chapter 6 original
    net = RainbowDQN(env.observation_space.shape, env.action_space.n)
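For what it's worth, the env.step(action) line is usually where a bad action index surfaces rather than being the real cause: RainbowDQN's forward() returns a probability distribution over N_ATOMS value atoms per action, not plain Q-values, so the action-selection code copied from Chapter 6 can argmax over the wrong axis and hand env.step() an out-of-range action. Below is a minimal sketch of the play loop using the RainbowDQN class copied above; it assumes the Chapter 8 model exposes a qvals() helper that collapses the atom distribution into expected Q-values (if your copy only defines forward(), compute the expectation over the atoms yourself), and the env name and model file name are placeholders:

import gym
import numpy as np
import torch

ENV_NAME = "PongNoFrameskip-v4"        # placeholder: use the env you actually trained on
MODEL_FILE = ENV_NAME + "-best.dat"    # file written by the modified RewardTracker below

if __name__ == "__main__":
    # in practice wrap the env exactly as in training (wrappers.make_env),
    # otherwise the observation shape/dtype won't match the network
    env = gym.make(ENV_NAME)
    net = RainbowDQN(env.observation_space.shape, env.action_space.n)
    net.load_state_dict(torch.load(MODEL_FILE, map_location="cpu"))

    state = env.reset()
    total_reward = 0.0
    while True:
        state_v = torch.tensor(np.array([state], copy=False))
        q_vals = net.qvals(state_v).data.numpy()[0]  # assumption: qvals() collapses atoms to Q-values
        action = int(np.argmax(q_vals))              # argmax over actions, not over the flat distribution
        state, reward, done, _ = env.step(action)
        total_reward += reward
        if done:
            break
    print("Total reward: %.2f" % total_reward)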

MODIFIED CODE (from common.py, Chapter 7)

# note: common.py also needs "import torch" at the top for the torch.save() call below
class RewardTracker:
    def __init__(self, writer, stop_reward, env_name, net):
        self.writer = writer
        self.stop_reward = stop_reward
        self.best_mean_reward = None
        self.env_name = env_name
        self.net = net

    def reward(self, reward, frame, epsilon=None):
        self.total_rewards.append(reward)
        speed = (frame - self.ts_frame) / (time.time() - self.ts)
        self.ts_frame = frame
        self.ts = time.time()
        mean_reward = np.mean(self.total_rewards[-100:])
        epsilon_str = "" if epsilon is None else ", eps %.2f" % epsilon
        print("%d: done %d games, mean reward %.3f, speed %.2f f/s%s" % (
            frame, len(self.total_rewards), mean_reward, speed, epsilon_str
        ))
        sys.stdout.flush()
        if epsilon is not None:
            self.writer.add_scalar("epsilon", epsilon, frame)
        self.writer.add_scalar("speed", speed, frame)
        self.writer.add_scalar("reward_100", mean_reward, frame)
        self.writer.add_scalar("reward", reward, frame)
        if self.best_mean_reward is None or self.best_mean_reward < mean_reward:
            torch.save(self.net.state_dict(), self.env_name + "-best.dat")
            if self.best_mean_reward is not None:
                print("Best mean reward updated %.3f -> %.3f, model saved" % (
                    self.best_mean_reward, mean_reward))
            self.best_mean_reward = mean_reward
        if mean_reward > self.stop_reward:
            print("Solved in %d frames!" % frame)
            return True
        return False
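One more thing to check: RewardTracker is used as a context manager, so the __enter__ and __exit__ methods from the original Chapter 7 common.py have to stay in the class next to the modified __init__ and reward(), since reward() relies on the fields they initialise. Roughly, they look like the sketch below (copy the exact versions from the book's common.py):

    def __enter__(self):
        # set up the fields that reward() uses for speed / mean-reward bookkeeping
        self.ts = time.time()
        self.ts_frame = 0
        self.total_rewards = []
        return self

    def __exit__(self, *args):
        # close the TensorBoard writer when the with-block ends
        self.writer.close()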

And the matching call in 08_dqn_rainbow.py, with the two extra arguments passed to the modified constructor:

with common.RewardTracker(writer, params['stop_reward'], params['env_name'], net) as reward_tracker:
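If you would rather leave common.py untouched, another option is to checkpoint straight from the training loop in 08_dqn_rainbow.py. A hedged sketch, assuming the loop counter is called frame_idx as in the other chapter scripts (adjust the name to whatever your loop actually uses):

# inside the training loop of 08_dqn_rainbow.py (variable names assumed, not verified)
if frame_idx % 100000 == 0:
    torch.save(net.state_dict(), "%s-%08d.dat" % (params['env_name'], frame_idx))

Either way, dqn_play.py only needs net.load_state_dict(torch.load(...)) on whichever .dat file you pick, as in the sketch further up.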