@jamescasbon do you have an example that shows a case where the regularization loss is not taken into account? This code is currently a bit messy, since part of it lives in the heads implementation and part in general_network, but to the best of our knowledge there is no bug here.
I don't have a handy example but I have the following evidence:
1. In PPO with `use_kl_regularization` set to `True`, you can set `kl_coefficient` to whatever you want and it will not affect the outcome of a run.
2. In the PPO head I can see the regularization loss from (1) in the `self.regularizations` list. I can also confirm (by instrumenting the code, roughly as sketched below) that it does not form part of the loss function used in `general_network`.
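For reference, a minimal instrumentation probe along those lines. This is a sketch to paste into the Coach code where both objects are in scope, not a standalone script; `head` and `network` are hypothetical handles standing in for the PPO head instance and the enclosing general_network:

```python
# Hypothetical handles: `head` is the PPO head and `network` the enclosing
# general_network; the names are illustrative, not Coach's public API.
collected = tf.losses.get_losses(network.full_name)  # what general_network sums
for reg in head.regularizations:                     # e.g. the KL penalty tensor
    in_collection = any(reg.name == t.name for t in collected)
    print(reg.name, in_collection)                   # prints False with the bug
```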
Looking at the construction of the loss, you have:

```python
self.losses = tf.losses.get_losses(self.full_name)

# L2 regularization
if self.network_parameters.l2_regularization != 0:
    self.l2_regularization = tf.add_n([tf.nn.l2_loss(v) for v in self.weights]) \
                             * self.network_parameters.l2_regularization
    self.losses += self.l2_regularization

self.total_loss = tf.reduce_sum(self.losses)
```
So a loss must either be registered in the `tf.losses` collection or be part of this L2 regularization term in order to end up in the loss function.
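A self-contained TF 1.x sketch of that mechanism: only tensors registered via `tf.losses.add_loss` come back from `tf.losses.get_losses`; a tensor merely kept in a Python list is invisible to the collection.

```python
import tensorflow as tf  # TF 1.x

registered = tf.constant(1.0, name='registered_loss')
tf.losses.add_loss(registered)  # goes into the GraphKeys.LOSSES collection

unregistered = tf.constant(2.0, name='unregistered_loss')
kept_in_a_list = [unregistered]  # mirrors head.py appending to self.loss

# Prints only 'registered_loss'; 'unregistered_loss' is never summed.
print(tf.losses.get_losses())
```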
You can see in `head.py`:
```python
# add losses and target placeholder
for idx in range(len(self.loss_type)):
    ...
    # we add the loss to the losses collection and later we will extract it in general_network
    tf.losses.add_loss(loss)
    self.loss.append(loss)

# add regularizations
for regularization in self.regularizations:
    self.loss.append(regularization)
```
So the regularization terms are appended to `self.loss`, but `tf.losses.add_loss` is never called on them, and nothing else ever adds them to the loss that `general_network` builds.
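For what it's worth, a minimal sketch of one possible fix (not necessarily the exact change in this PR): register each regularization in the losses collection in `head.py`, so that the `tf.losses.get_losses(self.full_name)` call in `general_network` picks it up.

```python
# add regularizations
for regularization in self.regularizations:
    # register in the tf.losses collection so that general_network's
    # tf.losses.get_losses(self.full_name) call extracts it
    tf.losses.add_loss(regularization)
    self.loss.append(regularization)
```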
What evidence do you have that it does work? ;) I found the L2 regularizer bug before and I'm pretty sure this is the same issue. To fix that one, you added the stanza in the construction of the loss, right?
Thanks @jamescasbon :-) @shadiendrawis can you please review to make sure all regularizations are covered and are taken into account only once with this fix?
The regularization losses are not added to the cost function and are therefore ignored during optimisation. This PR fixes this bug.