greentfrapp / lucent

Lucid library adapted for PyTorch
Apache License 2.0

A More Elegant Dead ReLU Fix #9

Open greentfrapp opened 4 years ago

greentfrapp commented 4 years ago

To quote from redirected_relu_grad.py in the original Lucid library:

When we visualize ReLU networks, the initial random input we give the model may not cause the neuron we're visualizing to fire at all. For a ReLU neuron, this means that no gradient flows backwards and the visualization never takes off. One solution would be to find the pre-ReLU tensor, but that can be tedious. These functions provide a more convenient solution: temporarily override the gradient of ReLUs to allow gradients to flow back through the ReLU -- even if it didn't activate and had a derivative of zero -- allowing the visualization process to get started. These functions override the gradient for at most 16 steps. Thus, you need to initialize global_step before using these functions.

Lucid uses TensorFlow, which allows gradients to be overridden with gradient_override_map (although Lucid replaces that with its own implementation). TensorFlow also makes it easy to keep track of a global step, which Lucid uses to make the "gradient fix" temporary (see above).
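For reference, the TensorFlow 1.x pattern looks roughly like this (just a sketch, not Lucid's actual override; the registered name and the gradient logic here are placeholders):

import tensorflow as tf  # TF 1.x graph-mode semantics assumed

@tf.RegisterGradient("RedirectedReluSketch")
def _redirected_relu_grad(op, grad):
    x = op.inputs[0]
    # placeholder logic: let a scaled-down gradient through even where the ReLU did not fire
    return tf.where(x < 0., 0.1 * grad, grad)

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, [None, 10])
    # inside this context, gradients of "Relu" ops use the override registered above
    with graph.gradient_override_map({"Relu": "RedirectedReluSketch"}):
        y = tf.nn.relu(x)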

In comparison, Lucent implements a hacky workaround that is much less sophisticated.

We simply replace the ReLU function with our own RedirectedReLU, which has a modified backward method. Where the gradient should be 0 (because the pre-ReLU input is negative), we simply scale the incoming gradient by 0.1 and let it through. See here for the exact implementation.
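Roughly, that modified backward looks something like this (a simplified sketch rather than the exact code linked above):

import torch

class RedirectedReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor):
        ctx.save_for_backward(input_tensor)
        return input_tensor.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # where a plain ReLU would pass zero gradient (negative pre-activation),
        # let a gradient scaled by 0.1 through instead
        neg = input_tensor < 0
        grad_input[neg] = grad_input[neg] * 0.1
        return grad_input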

We do this at the model-initialization stage, only for the InceptionV1 model, and we never switch off the redirected gradient. I suspect that never switching it off is not as bad as it might seem, because we are updating the input values rather than the model weights. In any case, this seems to work fine so far, but I would really prefer a more principled approach.

To be frank, I haven't spent too much time thinking about how to do this in torch. But the main elements of a better fix, primarily following Lucid's implementation, would be: a general way to override the ReLU gradient (the PyTorch analogue of gradient_override_map), and a way to track the optimization step so that the override is only active for roughly the first 16 steps and is switched off afterwards.
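Purely as a sketch of how the switch-off part could look (none of these names exist in Lucent; RedirectedReLU here is the Function sketched above, and the render loop would call switch_off_redirection once the step count passes 16):

import torch
import torch.nn.functional as F

class SwitchableRedirectedReLU(torch.nn.Module):
    # hypothetical module: redirects gradients only while `redirect` is True
    def __init__(self):
        super().__init__()
        self.redirect = True

    def forward(self, x):
        if self.redirect:
            return RedirectedReLU.apply(x)  # the Function sketched above
        return F.relu(x)

def switch_off_redirection(model):
    # called from the optimization loop after the first ~16 steps
    for module in model.modules():
        if isinstance(module, SwitchableRedirectedReLU):
            module.redirect = False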

Questions and discussions welcomed!

iacolippo commented 4 years ago

When I tried (and failed :rofl: ) to build a PyTorch port of Lucid, I did the following to implement the redirected ReLU in the optvis.render_vis function: conditionally open some contexts with ExitStack in which relu is overridden with the redirected version, then close the contexts at epoch 15. Some code here:

images = []
try:
    with ExitStack() as stack:
        if relu_gradient_override:
            from limpid.misc.redirected_relu_grad import redirect_relu_F, redirect_relu_nn
            # when entering these contexts, the relu becomes redirected
            stack.enter_context(redirect_relu_F())
            stack.enter_context(redirect_relu_nn())

        # instantiate the model inside the contexts so that its ReLUs
        # are created as redirected ReLUs
        model = model()
        for epoch in range(n_epochs):
            print(epoch)
            optimizer.zero_grad()
            out = model(image)
            loss = objective_f(out)
            loss.backward()
            optimizer.step()

            if epoch in thresholds:
                if verbose:
                    print('Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, n_epochs, loss))
                images.append(image.detach().clone())  # snapshot a copy, not a reference

            if epoch == 15:
                # stop using redirected relu grad trick at 16th epoch - closes the contexts previously opened
                print('closing stack')
                stack.close()

where the context managers are defined like this:

from contextlib import contextmanager
import torch as _torch
import torch.nn.functional as _F
from torch.autograd import Function as _Function

@contextmanager
def redirect_relu_F():
    temp = getattr(_F, 'relu')
    # default inplace=False so existing calls like F.relu(x) keep working
    setattr(_F, 'relu', lambda x, inplace=False: _redirected_relu_func(x, inplace))
    try:
        yield
    finally:
        # restore the original F.relu even if an exception is raised inside the context
        setattr(_F, 'relu', temp)

@contextmanager
def redirect_relu_nn():
    temp = getattr(_torch.nn, 'ReLU')
    # default inplace=False so both nn.ReLU() and nn.ReLU(inplace=True) keep working
    setattr(_torch.nn, 'ReLU', lambda inplace=False: RedirectedReLU(inplace))
    try:
        yield
    finally:
        setattr(_torch.nn, 'ReLU', temp)

class RedirectedReluFunction(_Function):
    @staticmethod
    def forward(ctx, input, inplace=False):
        # save the pre-ReLU values for the backward pass; clone when operating
        # in place so the saved tensor is not clobbered by relu_
        ctx.save_for_backward(input.clone() if inplace else input)
        if inplace:
            ctx.mark_dirty(input)
            return _torch.relu_(input)
        return _torch.relu(input)

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_tensors[0]
        grad_input = None

        if ctx.needs_input_grad[0]:
            # ordinary ReLU gradient: zero wherever the pre-ReLU input was negative
            relu_grad = _torch.where(input < 0, _torch.zeros_like(grad_output), grad_output)

            # redirected gradient: only zero out the components that would push an
            # already-negative input even lower (negative input, positive grad
            # under gradient descent)
            redirected_grad = _torch.where((input < 0) & (grad_output > 0),
                                           _torch.zeros_like(grad_output), grad_output)

            # only use the redirected gradient for samples where nothing got
            # through the ordinary ReLU gradient
            grad_mag = _torch.norm(relu_grad.view(relu_grad.size(0), -1), dim=1)
            use_relu_grad = (grad_mag > 0.).view(-1, *([1] * (relu_grad.dim() - 1)))
            grad_input = _torch.where(use_relu_grad, relu_grad, redirected_grad)

        # gradient w.r.t. the inplace flag is always None
        return grad_input, None

_redirected_relu_func = RedirectedReluFunction.apply

class RedirectedReLU(_torch.nn.Module):
    def __init__(self, inplace=False):
        super(RedirectedReLU, self).__init__()
        self.inplace = inplace

    def forward(self, input):
        return _redirected_relu_func(input, inplace=self.inplace)

I'm not sure the whole implementation is correct, but you get the gist.

greentfrapp commented 4 years ago

Oh yes, I just recently looked into context managers and was wondering if that might work for this! Thank you @iacolippo for the tip! I'll take a closer look soon and hopefully get around to implementing this.

iacolippo commented 4 years ago

I don't have a lot of time on my hands, but I'm happy to help with feedback and/or minor tasks if needed. Cheers