ufal / npfl122

NPFL122 repository
Creative Commons Attribution Share Alike 4.0 International

Simplify and improve pytorch code for QNetwork #81

Closed Fassty closed 2 years ago

Fassty commented 2 years ago

Regarding the copying of weights: you probably want to keep both the policy and the target network in the same Network class, since the target network is needed to calculate the target Q values when computing the loss. So copy_weights_from should rather look like:

def update_target_net(self):
    # torch.no_grad is not required here; load_state_dict copies the tensors
    # without tracking gradients
    self._target_net.load_state_dict(self._model.state_dict())

or we could provide a method for doing the soft (Polyak) update right away

def soft_update_weights(self, tau):
    # In-place EMA update: target <- (1 - tau) * target + tau * online
    with torch.no_grad():
        for param, target_param in zip(self._model.parameters(), self._target_net.parameters()):
            target_param.data.mul_(1 - tau).add_(param.data, alpha=tau)
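A minimal NumPy sketch of the same soft update (using made-up flat weight vectors rather than network parameters), showing that repeated applications of target ← (1 − tau) · target + tau · online drive the target toward the online weights:

```python
import numpy as np

def soft_update(target, online, tau):
    """In-place EMA update: target <- (1 - tau) * target + tau * online."""
    target *= 1 - tau
    target += tau * online

online = np.array([1.0, 2.0, 3.0])
target = np.zeros(3)

# Repeated soft updates converge geometrically to the online weights:
# after n steps the residual is (1 - tau)^n of the initial gap.
for _ in range(1000):
    soft_update(target, online, tau=0.01)

print(np.allclose(target, online, atol=1e-3))  # True
```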
foxik commented 2 years ago

Yes, there are definitely many ways the copying/averaging can be done. In the template, I prefer the "separate Network, hard copying" approach, because I find a separate Network easier at first (and you also do not want to serialize it after training), and hard copying is what DQNs do. (But in later papers we will also see the soft EMA; note that "Polyak" averaging is an arithmetic average, not the exponential moving average you implemented in the second snippet above.)
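To illustrate the distinction (a small sketch with made-up scalar "weights"): Polyak averaging keeps the running arithmetic mean of all iterates, while the update in the snippet above is an exponential moving average that forgets old iterates geometrically:

```python
import numpy as np

iterates = np.array([1.0, 2.0, 3.0, 4.0])

# Polyak averaging: arithmetic mean of all iterates seen so far.
polyak = iterates.mean()  # (1 + 2 + 3 + 4) / 4 = 2.5

# Exponential moving average with rate tau: recent iterates dominate.
tau = 0.5
ema = iterates[0]
for theta in iterates[1:]:
    ema = (1 - tau) * ema + tau * theta

print(polyak, ema)  # 2.5 3.125 -- the EMA is pulled toward the latest iterate
```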

However, thanks for the simplified code, that is definitely better!

Regarding the expansion of the state -- that works fine if the state is a numpy array; I am just not sure this is promised in the documentation... (reading through the docs) Oh, it is -- great. Then I am merging it and will rely on it (but I will use the syntax state[np.newaxis], which I like better).

Also, just for fun:

In [2]: %timeit np.array([s])
603 ns ± 6.18 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [3]: %timeit s[np.newaxis]
168 ns ± 1.38 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

In [4]: %timeit np.asarray(s[np.newaxis])
226 ns ± 2.72 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
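For completeness, a quick check (with a hypothetical state vector `s`) that the two spellings produce the same batch of one; `s[np.newaxis]` is faster because it returns a view with a leading axis rather than copying the data:

```python
import numpy as np

s = np.arange(4, dtype=np.float32)  # a made-up state vector

batch_copy = np.array([s])      # allocates a new array and copies s
batch_view = s[np.newaxis]      # adds a leading axis; no data copy

print(batch_copy.shape, batch_view.shape)      # (1, 4) (1, 4)
print(np.array_equal(batch_copy, batch_view))  # True
print(batch_view.base is s)                    # True -- it is a view onto s
```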

A community-work point is yours, BTW.