cnellington / Contextualized

An SKLearn-style toolbox for estimating and analyzing models, distributions, and functions with context-specific parameters.
http://contextualized.ml/
GNU General Public License v3.0

PyTorch NGAM is very slow #118

Open · cnellington opened this issue 2 years ago

cnellington commented 2 years ago

Even when the MLP is instantiated with many more parameters than the NGAM, the MLP module runs about 6-8x faster than the NGAM module. The slowdown is here: https://github.com/cnellington/Contextualized/blob/3e4fba4f9166e023b17d091d6adba70c0804525a/contextualized/modules.py#L75
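For illustration, here is a minimal sketch of the kind of per-feature loop that typically makes an NGAM slow. It assumes (this is not confirmed by the linked line) that the module runs one small MLP per input feature and sums the results. `LoopNGAM` and `BatchedNGAM` are hypothetical names, not classes from Contextualized. The batched version is a swapped-in vectorization: it stacks the per-feature weights along a leading feature axis and evaluates every feature with one broadcasted multiply and one `torch.einsum`.

```python
import torch
import torch.nn as nn


class LoopNGAM(nn.Module):
    """Hypothetical per-feature NGAM: one small MLP per input feature, outputs summed."""

    def __init__(self, n_features: int, hidden: int, n_out: int):
        super().__init__()
        self.nets = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(1, hidden), nn.ReLU(), nn.Linear(hidden, n_out))
                for _ in range(n_features)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One Python iteration (and several small kernel launches) per feature.
        return sum(net(x[:, j : j + 1]) for j, net in enumerate(self.nets))


class BatchedNGAM(nn.Module):
    """Same function as a LoopNGAM, with per-feature weights stacked on a feature axis."""

    def __init__(self, loop: LoopNGAM):
        super().__init__()
        # Copy the looped model's weights so both versions compute the same function.
        self.w1 = nn.Parameter(torch.stack([n[0].weight[:, 0] for n in loop.nets]).detach().clone())  # (F, H)
        self.b1 = nn.Parameter(torch.stack([n[0].bias for n in loop.nets]).detach().clone())          # (F, H)
        self.w2 = nn.Parameter(torch.stack([n[2].weight for n in loop.nets]).detach().clone())        # (F, O, H)
        self.b2 = nn.Parameter(torch.stack([n[2].bias for n in loop.nets]).detach().clone())          # (F, O)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, F). All first layers at once via broadcasting: (B, F, H).
        h = torch.relu(x.unsqueeze(-1) * self.w1 + self.b1)
        # All second layers at once: (B, F, O); then sum the per-feature outputs: (B, O).
        out = torch.einsum("bfh,foh->bfo", h, self.w2) + self.b2
        return out.sum(dim=1)
```

Both versions compute the same function from the same weights. The looped forward issues separate small matmuls for each feature, so launch overhead grows with the number of features, while the batched forward uses a fixed number of operations and stores only the per-feature weights, with no zero padding. Whether this closes the 6-8x gap would still need to be benchmarked.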

Some things I've tried, with no success:

My naive idea for moving forward is to use a standard linear layer with some complicated masking to enforce additivity, but this would allocate far more memory than necessary.
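The masking idea can be sketched as follows, assuming an NGAM whose per-feature subnetworks have two hidden layers of width H; `MaskedLinear` and `MaskedNGAM` are hypothetical names. Each dense layer's weight is multiplied by a fixed block-diagonal mask so that hidden block j depends only on input feature j. The final layer can stay unmasked, because a linear combination of per-feature blocks is already additive.

```python
import torch
import torch.nn as nn


class MaskedLinear(nn.Linear):
    """Dense linear layer whose weight is multiplied by a fixed 0/1 mask on every call."""

    def __init__(self, in_features: int, out_features: int, mask: torch.Tensor):
        super().__init__(in_features, out_features)
        # mask must match self.weight: (out_features, in_features); stored as a buffer, not trained.
        self.register_buffer("mask", mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # weight * mask allocates a full dense temporary on each forward pass.
        return nn.functional.linear(x, self.weight * self.mask, self.bias)


class MaskedNGAM(nn.Module):
    """Hypothetical additive model: per-feature subnetworks packed into dense masked layers."""

    def __init__(self, n_features: int, hidden: int, n_out: int):
        super().__init__()
        nf, h = n_features, hidden
        # Layer 1 mask (nf*h, nf): hidden block j sees only input feature j.
        m1 = torch.block_diag(*[torch.ones(h, 1) for _ in range(nf)])
        # Layer 2 mask (nf*h, nf*h): hidden block j sees only hidden block j.
        m2 = torch.block_diag(*[torch.ones(h, h) for _ in range(nf)])
        self.net = nn.Sequential(
            MaskedLinear(nf, nf * h, m1),
            nn.ReLU(),
            MaskedLinear(nf * h, nf * h, m2),
            nn.ReLU(),
            # Unmasked output layer: a weighted sum of per-feature blocks stays additive.
            nn.Linear(nf * h, n_out),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_features) -> (batch, n_out)
        return self.net(x)
```

The memory cost shows up in the hidden-to-hidden layer: its dense weight has (F·H)² entries, but only F·H² of them are unmasked. With F = 100 features and H = 32, that is 10,240,000 stored weights for 102,400 useful ones, plus a mask buffer of the same size and a temporary `weight * mask` product on every forward pass.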