greentfrapp opened this issue 4 years ago
When I tried (and failed :rofl:) to build a PyTorch port of Lucid, I implemented the redirected ReLU in the optvis.render_vis function like this: conditionally open some contexts with ExitStack in which relu is overridden with the redirected relu, then close the contexts at epoch 15. Some code here:
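from contextlib import ExitStack

images = []
try:
    with ExitStack() as stack:
        if relu_gradient_override:
            from limpid.misc.redirected_relu_grad import redirect_relu_F, redirect_relu_nn
            # while these contexts are open, F.relu and nn.ReLU are replaced by the redirected versions
            stack.enter_context(redirect_relu_F())
            stack.enter_context(redirect_relu_nn())
        # build the model inside the contexts so that its nn.ReLU layers are created as RedirectedReLU
        model = model()
        for epoch in range(n_epochs):
            optimizer.zero_grad()
            out = model(image)
            loss = objective_f(out)
            loss.backward()
            optimizer.step()
            if epoch in thresholds:
                if verbose:
                    print('Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, n_epochs, loss.item()))
                # store a copy: appending `image` itself keeps a reference that later steps overwrite
                images.append(image.detach().clone())
            if epoch == 15:
                # stop using the redirected relu grad from the 16th epoch on by closing the contexts opened above.
                # Note: this restores F.relu and nn.ReLU, but RedirectedReLU modules already built into the model stay.
                stack.close()
except KeyboardInterrupt:
    # assumed handler: the except clause that followed in the original code is not shown
    print('Interrupted optimization')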
where the context managers are defined like this:
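from contextlib import contextmanager

import torch as _torch
import torch.nn.functional as _F
from torch.autograd import Function as _Function

@contextmanager
def redirect_relu_F():
    temp = _F.relu
    # match F.relu's signature, where inplace is optional
    _F.relu = lambda input, inplace=False: _redirected_relu_func(input, inplace)
    try:
        yield
    finally:
        # restore the original function even if the body raises
        _F.relu = temp

@contextmanager
def redirect_relu_nn():
    temp = _torch.nn.ReLU
    # nn.ReLU() is often constructed without arguments, so inplace must be optional too
    _torch.nn.ReLU = lambda inplace=False: RedirectedReLU(inplace)
    try:
        yield
    finally:
        _torch.nn.ReLU = temp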
class RedirectedReluFunction(_Function):
    @staticmethod
    def forward(ctx, input, inplace=False):
        # inplace is accepted for API compatibility but ignored: modifying the saved input
        # in place would overwrite the values that backward needs
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = None
        if ctx.needs_input_grad[0]:
            # correct ReLU gradient: pass grad_output through only where the input is not negative
            relu_grad = _torch.where(input < 0, _torch.zeros_like(grad_output), grad_output)
            # redirected gradient: zero it only where it would push an already-negative input lower
            # (with gradient descent, a positive gradient decreases the input)
            neg_pushing_lower = (input < 0) & (grad_output > 0)
            redirected_grad = _torch.where(neg_pushing_lower, _torch.zeros_like(grad_output), grad_output)
            # only use the redirected gradient for batch items where nothing got through the original gradient
            relu_grad_mag = relu_grad.reshape(relu_grad.size(0), -1).norm(dim=1)
            # reshape the per-item magnitude to (batch, 1, 1, ...) so it broadcasts against the gradients
            mask = (relu_grad_mag > 0).view(-1, *([1] * (relu_grad.dim() - 1)))
            grad_input = _torch.where(mask, relu_grad, redirected_grad)
        # gradient wrt the inplace flag is always None
        return grad_input, None

_redirected_relu_func = RedirectedReluFunction.apply
class RedirectedReLU(_torch.nn.Module):
    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input):
        # pass arguments to Function.apply positionally
        return _redirected_relu_func(input, self.inplace)
I'm not sure the whole implementation is correct, but you get the gist.
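A minimal, framework-free sketch of the ExitStack pattern above, assuming a hypothetical `fake_nn` namespace as a stand-in for `torch.nn` and a hypothetical helper `patch_attr`. It also illustrates one caveat of the approach: closing the stack restores the attribute, but anything constructed while the patch was active (such as the ReLU layers created by `model = model()`) keeps the patched class.

```python
from contextlib import ExitStack, contextmanager
from types import SimpleNamespace

# Hypothetical stand-in for a module such as torch.nn: only the ReLU attribute matters here.
fake_nn = SimpleNamespace(ReLU=lambda: "ReLU")


@contextmanager
def patch_attr(obj, name, value):
    """Temporarily replace obj.name with value; restore it on exit, even after an error."""
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, original)


def run(n_steps, override=True, switch_off_at=15):
    layers_per_step = []
    with ExitStack() as stack:
        if override:  # enter the patching context only when requested
            stack.enter_context(patch_attr(fake_nn, "ReLU", lambda: "RedirectedReLU"))
        built_early = fake_nn.ReLU()  # the "model" is built while the patch is active
        for step in range(n_steps):
            # which ReLU would a layer constructed at this step get?
            layers_per_step.append(fake_nn.ReLU())
            if step == switch_off_at:
                stack.close()  # undo the patch now; the exit of the with-block is then a no-op
    return built_early, layers_per_step


built_early, layers = run(20)
print(built_early)             # RedirectedReLU: objects built while patched keep the patch
print(layers[15], layers[16])  # RedirectedReLU ReLU
print(fake_nn.ReLU())          # ReLU: the original attribute is restored
```

Because ExitStack.close() empties the stack, the implicit close when the with-block ends does nothing, so closing early inside the loop is safe.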
Oh yes, I just recently looked into context managers and was wondering if that might work for this! Thank you @iacolippo for the tip! I'll take a closer look soon and hopefully get around to implementing this.
I don't have a lot of time on my hands, but I'm happy to help with feedback and/or minor tasks if needed. Cheers!
To quote from redirected_relu_grad.py in the original Lucid library:
Lucid uses TensorFlow, which allows gradient overrides with gradient_override_map (although Lucid replaces that with its own implementation). TensorFlow can also keep track of the global step, and Lucid uses this to make the "gradient fix" temporary (see above).
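PyTorch has no built-in global step, so one option is to keep a step counter on the module itself. Below is a sketch under that assumption: `_RedirectedReluFn` ports the gradient rule from Lucid's redirected_relu_grad.py as we read it, and the hypothetical `StepGatedRedirectedReLU` module falls back to the plain ReLU gradient once its `step` counter exceeds 16, mirroring Lucid's `global_step > 16` check. The caller is assumed to increment `step` once per optimization step.

```python
import torch
from torch import nn


class _RedirectedReluFn(torch.autograd.Function):
    """ReLU forward; backward follows the redirected-gradient rule when `redirect` is True."""

    @staticmethod
    def forward(ctx, x, redirect):
        ctx.save_for_backward(x)
        ctx.redirect = redirect  # plain Python bool, fixed for this forward pass
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad):
        (x,) = ctx.saved_tensors
        zeros = torch.zeros_like(grad)
        relu_grad = torch.where(x < 0, zeros, grad)  # the true ReLU gradient
        if not ctx.redirect:
            return relu_grad, None
        # zero the incoming gradient only where it would push an already-negative input lower
        redirected = torch.where((x < 0) & (grad > 0), zeros, grad)
        # per batch item: keep the true gradient if any of it got through, else redirect
        mag = relu_grad.reshape(relu_grad.shape[0], -1).norm(dim=1)
        keep = (mag > 0).view(-1, *([1] * (grad.dim() - 1)))
        return torch.where(keep, relu_grad, redirected), None


class StepGatedRedirectedReLU(nn.Module):
    """Hypothetical module: redirects gradients while step <= switch_off_step."""

    def __init__(self, switch_off_step=16):
        super().__init__()
        self.switch_off_step = switch_off_step
        self.step = 0  # stands in for TensorFlow's global step; the caller advances it

    def forward(self, x):
        return _RedirectedReluFn.apply(x, self.step <= self.switch_off_step)


relu = StepGatedRedirectedReLU()
x = torch.full((2, 3), -1.0, requires_grad=True)
(-relu(x)).sum().backward()  # maximize the activation of an all-negative input
print(x.grad)  # every entry is -1: gradient descent will push x upward
```

Unlike the ExitStack approach, the switch-off here also applies to layers that are already part of the model, since the gate is checked on every forward pass.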
In comparison, Lucent implements a hacky workaround that is much less sophisticated.
We simply replace the ReLU function with our own RedirectedReLU, which has a modified backward method. Where the gradient would be 0 (because the input is negative), we instead scale the gradient by 0.1 and let it through. See here for the exact implementation. We do this at the model initialization stage, only for the InceptionV1 model, and we never switch off the redirected gradient. I suspect that never switching it off is not as bad as it might seem, because we are updating the input values rather than the model weights. In any case, this seems to work fine so far, but I would really prefer a more principled approach.
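The scale-by-0.1 workaround described above can be sketched as a custom autograd Function. This is a minimal sketch with our own names (`LeakyGradReLU` and `leaky_grad_relu` are hypothetical, not Lucent's identifiers): the forward pass is an ordinary ReLU, and the backward pass multiplies the gradient at negative inputs by `scale` instead of zeroing it.

```python
import torch


class LeakyGradReLU(torch.autograd.Function):
    """ReLU forward; in backward, gradients at negative inputs are scaled instead of zeroed."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x)
        ctx.scale = scale
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # where the true ReLU gradient would be 0, let a scaled-down copy through
        grad_input = torch.where(x < 0, grad_output * ctx.scale, grad_output)
        return grad_input, None  # no gradient for the scale argument


def leaky_grad_relu(x, scale=0.1):
    # pass every argument explicitly so backward's return count matches apply's inputs
    return LeakyGradReLU.apply(x, scale)


x = torch.tensor([-2.0, 0.5], requires_grad=True)
leaky_grad_relu(x).sum().backward()
print(x.grad)  # tensor([0.1000, 1.0000])
```

Since the scaled gradient is applied unconditionally, the gradient is always slightly "wrong" for negative inputs, which is why a gated version like Lucid's is preferable.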
To be frank, I haven't spent much time thinking about how to do this in torch. But here are the main elements of a better fix, primarily following Lucid's implementation:
backward function?)
Questions and discussions welcome!