Even when instantiating the MLP with many more parameters than the NGAM, the MLP module is about 6x-8x faster than the NGAM module. The slowdown is here: https://github.com/cnellington/Contextualized/blob/3e4fba4f9166e023b17d091d6adba70c0804525a/contextualized/modules.py#L75

Some things I've tried, with no success (each variant is sketched in the snippet below):

- Removing the pre-grad memory allocation by setting the initial `ret` value to a NAM output
- Removing the extra sum operations by storing the NAM outputs in a list and summing once with `torch.stack(nam_outputs).sum(dim=0)`
- Removing type-based `+=` interpretation by using only torch-native addition, `torch.add(tensor1, tensor2)`
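Roughly, the original pattern and the variants above look like this (a minimal sketch only; `NGAMSketch` and its attribute names are illustrative stand-ins, not the actual module behind the link):

```python
import torch
import torch.nn as nn

class NGAMSketch(nn.Module):
    """Additive model: one small sub-network ("NAM") per input feature."""

    def __init__(self, input_dim, output_dim, hidden_dim=32):
        super().__init__()
        self.nams = nn.ModuleList(
            nn.Sequential(
                nn.Linear(1, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim),
            )
            for _ in range(input_dim)
        )

    def forward(self, x):
        # Variants 1 and 3: seed `ret` with the first NAM output (no upfront
        # zeros allocation) and accumulate with torch-native torch.add.
        ret = self.nams[0](x[:, 0:1])
        for i, nam in enumerate(self.nams[1:], start=1):
            ret = torch.add(ret, nam(x[:, i : i + 1]))
        return ret

    def forward_stacked(self, x):
        # Variant 2: collect all NAM outputs, then reduce once.
        nam_outputs = [nam(x[:, i : i + 1]) for i, nam in enumerate(self.nams)]
        return torch.stack(nam_outputs).sum(dim=0)

# Both paths compute the same sum (up to float reordering):
model = NGAMSketch(input_dim=100, output_dim=1)
x = torch.randn(32, 100)
y1, y2 = model(x), model.forward_stacked(x)
```

Note that every variant still dispatches one tiny matmul per feature from a Python loop, so the per-feature launch overhead survives each rewrite; that would be consistent with none of them closing the gap to the MLP's single large matmuls.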
My naive idea for moving forward would be to use a standard linear layer with some complicated masking to enforce additivity (roughly as sketched below), but this would allocate a lot more memory than necessary.
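A minimal sketch of that masked-linear idea, with all names hypothetical: a single wide layer stands in for the per-feature sub-networks, and a block-diagonal mask zeroes every cross-feature weight, so the model stays additive while the forward pass is one big matmul. The mask is also where the memory waste comes in, since the full dense weight is allocated even though only its block diagonal is ever used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedAdditiveSketch(nn.Module):
    """One wide masked layer replacing input_dim tiny per-feature layers."""

    def __init__(self, input_dim, output_dim, hidden_dim=32):
        super().__init__()
        # Dense weight is (input_dim * hidden_dim) x input_dim: a factor of
        # input_dim more entries than the per-feature layers actually need.
        self.hidden = nn.Linear(input_dim, input_dim * hidden_dim)
        self.out = nn.Linear(input_dim * hidden_dim, output_dim)
        # Block-diagonal mask: hidden block j may only see feature x_j.
        mask = torch.zeros(input_dim * hidden_dim, input_dim)
        for j in range(input_dim):
            mask[j * hidden_dim : (j + 1) * hidden_dim, j] = 1.0
        self.register_buffer("mask", mask)

    def forward(self, x):
        # Masking on every call keeps cross-feature weights at zero, so the
        # output is a sum of per-feature functions (i.e., still additive).
        h = F.relu(F.linear(x, self.hidden.weight * self.mask, self.hidden.bias))
        return self.out(h)

model = MaskedAdditiveSketch(input_dim=100, output_dim=1)
y = model(torch.randn(32, 100))
```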