`lipschitz_projection` was failing when a CUDA device other than the default is specified (e.g. `torch.device('cuda:1')`).
This fix always moves the `lip_reg` tensor to the correct device.
Please test whether it works for you!
This should not interfere with our existing code, because the device is inferred directly from the weights tensor, so no new arguments need to be passed.
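Below is a minimal sketch of the device handling described above, assuming a PyTorch setup. The signature `lipschitz_projection(weight, lip_reg)` and the projection rule (rescaling by the infinity operator norm, i.e. the maximum absolute row sum) are hypothetical illustrations, not the code changed in this PR. The relevant part is that `lip_reg` is created on `weight.device`, so callers never pass a device argument.

```python
from typing import Union

import torch


def lipschitz_projection(weight: torch.Tensor,
                         lip_reg: Union[float, torch.Tensor]) -> torch.Tensor:
    """Hypothetical stand-in for the real function; only the device logic mirrors the fix."""
    # Infer device and dtype from the weights tensor itself. torch.as_tensor copies
    # lip_reg to weight.device if it lives elsewhere (e.g. CPU vs. cuda:1).
    lip_reg = torch.as_tensor(lip_reg, dtype=weight.dtype, device=weight.device)

    # Illustrative projection (an assumption): infinity operator norm = max absolute row sum.
    norm = weight.abs().sum(dim=1).max()

    # Rescale only when the norm exceeds lip_reg; otherwise the scale is 1.
    scale = torch.clamp(norm / lip_reg, min=1.0)
    return weight / scale


if __name__ == "__main__":
    # Reproduces the original failure scenario, only on machines with 2+ GPUs.
    if torch.cuda.device_count() > 1:
        w = torch.randn(4, 4, device=torch.device("cuda:1"))
        print(lipschitz_projection(w, 1.0).device)  # cuda:1
```

Because the device comes from `weight`, existing call sites keep working unchanged, whether the weights sit on the CPU, `cuda:0`, or `cuda:1`.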
Development:
- [ ] Add tests
Checks:
nosetests