google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

Stateful functions and target networks? #49

Closed KristianHolsheimer closed 4 years ago

KristianHolsheimer commented 4 years ago

Hi there,

I'd like to ask for some advice.

Let's say I've got a function approximator for a q-function that uses hk.{get,set}_state() along with hk.transform_with_state(). This means that my function approximator consists of a triplet func, params, state.

I would like to keep a separate copy of this function approximator, i.e. a target network. This means that I keep a separate copy of the params.

Now my question is, would you recommend I also keep a separate copy of the state? And if so, how do we ensure that the usual smooth updates make sense? (e.g. variance typically can't be updated this way unless you map it to the 2nd moment first)

KristianHolsheimer commented 4 years ago

The component I'm thinking of here is hk.BatchNorm. The state (average mean/variance of the activations) clearly depends on the current params.

My intuition doesn't really help me here, because I feel like I can wave my hands and tell myself two opposing stories. One story might be that leakage from primary to target networks isn't nearly as bad as leakage in supervised learning. The other story might be to imagine making this primary-to-target leakage extreme by picking a small batch size and small momentum parameter.

Anyway, I'm curious to hear your experience/opinion on this.

cgarciae commented 4 years ago

In other frameworks you would probably keep a copy of both without thinking too much about it, since most frameworks don't treat state any differently from weights (apart from stopping the gradients). I'm not sure whether that is good, but it may give you more confidence to know how others implement it.

trevorcai commented 4 years ago

In the specific case of target networks as used in DQN and related literature, the faithful thing to do would be to freeze both params and state when updating your target network. This means setting is_training=False for BatchNorm.

As you mention correctly, batchnorm state and params have strong co-dependence. This is particularly tricky in DRL, since the data distribution (and therefore the ideal values for BN state) changes with your policy, which is presumably parameterized by params.

I don't follow your concerns with the "smoothness" of second moment updates. If you're operating in the typical DQN target network regime, you'll be smoothly updating your second moment estimates in your online network. The target network simply serves as a snapshot of the online network (to fight overestimation etc).

That said, it's possible that leaking online experience to your target network helps - but I'd generally want to see a good justification for why that might be good behavior. (Metapoint: the cool thing about JAX + Haiku is that it makes these decisions explicit, rather than implicit!)

KristianHolsheimer commented 4 years ago

@trevorcai First of all, on your meta point.. I totally agree, this is exactly why I love jax and haiku.

My concern about smooth target-network updates was that some state may not be updated properly by doing

target_state = (1 - tau) * target_state + tau * primary_state

This works for point estimates (e.g. model weights or batchnorm means), but the batchnorm variances don't combine this way. Instead, you might want to update the batchnorm state as

moments = array([mean, var + mean ** 2])
moments = (1 - tau) * moments + tau * moments_new
mean, var = moments[0], moments[1] - moments[0] ** 2
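To make the difference concrete, here is a small JAX sketch that compares the naive leaf-wise smooth update with the moment-based one. The helper names `soft_update` and `soft_update_mean_var` and the toy numbers are assumptions made for illustration, not part of Haiku:

```python
import jax
import jax.numpy as jnp


def soft_update(target, online, tau):
    """Polyak-average every leaf of two matching pytrees."""
    return jax.tree_util.tree_map(
        lambda t, o: (1.0 - tau) * t + tau * o, target, online)


def soft_update_mean_var(mean_t, var_t, mean_o, var_o, tau):
    """Blend (mean, var) pairs through their first and second raw moments."""
    mean = (1.0 - tau) * mean_t + tau * mean_o
    second = (1.0 - tau) * (var_t + mean_t ** 2) + tau * (var_o + mean_o ** 2)
    return mean, second - mean ** 2


tau = 0.5
mean_t, var_t = jnp.array(0.0), jnp.array(1.0)  # target statistics
mean_o, var_o = jnp.array(2.0), jnp.array(1.0)  # online statistics

# Naive update: treats the variance like any other leaf.
naive = soft_update({"mean": mean_t, "var": var_t},
                    {"mean": mean_o, "var": var_o}, tau)

# Moment-based update: averages E[x] and E[x^2], then recovers the variance.
mean, var = soft_update_mean_var(mean_t, var_t, mean_o, var_o, tau)

print(naive["var"], var)  # naive variance 1.0 vs. moment-based variance 2.0
```

Both updates agree on the mean (1.0 here). The moment-based variance exceeds the naive one by tau * (1 - tau) * (mean_o - mean_t) ** 2, so the two only coincide when the online and target means are equal.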

But come to think of it, although it's a headache to treat batchnorm state/params differently from the others, it's certainly doable.

Thanks for your comments, I was about to implement my TD learning updaters today, so this is very useful!