PPPLDeepLearning / plasma-python

PPPL deep learning disruption prediction package
http://tigress-web.princeton.edu/~alexeys/docs-web/html/

Loss scaling #30

Closed ASvyatkovskiy closed 6 years ago

ASvyatkovskiy commented 6 years ago

This pull request enables loss scaling and augments the global weight update to facilitate low-precision calculation.
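For background, loss scaling multiplies the loss by a constant before the backward pass so that small gradients stay representable in half precision, and the resulting weight deltas are divided by the same constant afterwards. A minimal numpy sketch of the idea (illustrative values only, not package code):

import numpy as np

# a true gradient value that underflows in fp16
tiny_grad = 1e-8
print(np.float16(tiny_grad))            # 0.0 -- the update would be lost

scale = 1024.0                          # hypothetical loss_scale_factor
scaled = np.float16(tiny_grad * scale)  # ~1.02e-05, representable in fp16
recovered = float(scaled) / scale       # unscaling afterwards recovers ~1e-08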

Loss scaling is controlled from the config:

model:
    loss_scale_factor: 1.0

which is propagated to plasma.models.targets, for example:

class BinaryTarget(Target):
    activation = 'sigmoid'
    loss = 'binary_crossentropy'

    @staticmethod
    def loss_np(y_true, y_pred):
        # scale the numpy-side loss so it matches the scaled training loss
        return conf['model']['loss_scale_factor'] * binary_crossentropy_np(y_true, y_pred)
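The numpy-side loss relies on binary_crossentropy_np, the numpy counterpart of the Keras loss. A minimal sketch of such a helper, assuming clipped mean binary crossentropy (not necessarily the package's exact implementation):

import numpy as np

def binary_crossentropy_np(y_true, y_pred, eps=1e-7):
    # clip predictions away from 0 and 1 to avoid log(0)
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -np.mean(y_true * np.log(y_pred)
                    + (1.0 - y_true) * np.log(1.0 - y_pred))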

The steps in plasma.models.train_on_batch_and_get_deltas are rearranged:

    #unscale before subtracting -- used to be subtract then unscale
    weights_before_update = multiply_params(weights_before_update, 1.0/self.DUMMY_LR)
    weights_after_update = multiply_params(weights_after_update, 1.0/self.DUMMY_LR)

    deltas = subtract_params(weights_after_update, weights_before_update)

    #undo the loss scaling on the deltas
    if conf['model']['loss_scale_factor'] != 1.0:
        deltas = multiply_params(deltas, 1.0/conf['model']['loss_scale_factor'])
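The snippet treats model weights as a list of per-layer numpy arrays; multiply_params and subtract_params apply their operation across that list. A sketch of what such helpers look like, with assumed signatures rather than the package's exact code:

def multiply_params(params, scale):
    # scale every weight array in the list by a scalar
    return [p * scale for p in params]

def subtract_params(params_a, params_b):
    # elementwise difference of two matching lists of weight arrays
    return [a - b for a, b in zip(params_a, params_b)]

Unscaling before the subtraction, rather than after, presumably keeps the small weight differences from underflowing in low precision.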