Mambular is a Python package that brings the power of Mamba architectures to tabular data, offering a suite of deep learning models for regression, classification, and distributional regression tasks. This includes models like Mambular, FT-Transformer, TabTransformer and tabular ResNets.
This pull request introduces several significant changes to the mambular project, focusing on adding new modules, implementing utility functions, and simplifying the normalization layer selection in base models. The most important changes are the addition of the NODE (Neural Oblivious Decision Ensembles) model and the refactoring of normalization layer selection.
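The NODE model added in this pull request is built from the ODST and DenseBlock classes in mambular/arch_utils/node_utils.py. The sketch below shows how a differentiable oblivious tree layer can work in principle. In an oblivious tree, every node on the same level uses the same feature and threshold. The class names SimpleODST and SimpleDenseBlock are hypothetical. Softmax stands in for the sparse transforms (such as entmax or sparsemax) that NODE-style trees typically use for feature selection. The data-aware threshold initialization is omitted, and the DenseNet-style connectivity is an assumption. This is not mambular's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleODST(nn.Module):
    """Hypothetical, simplified oblivious decision tree layer (not mambular's ODST).

    Each of `num_trees` trees has `depth` levels. All nodes on one level share a
    single softly selected feature and threshold, which makes the tree "oblivious".
    """

    def __init__(self, in_features, num_trees=4, depth=3, tree_dim=1):
        super().__init__()
        self.num_trees, self.depth, self.tree_dim = num_trees, depth, tree_dim
        # One distribution over input features per (tree, level).
        self.feature_logits = nn.Parameter(torch.randn(in_features, num_trees, depth))
        self.thresholds = nn.Parameter(torch.zeros(num_trees, depth))
        self.log_temperatures = nn.Parameter(torch.zeros(num_trees, depth))
        # A learnable response vector for each of the 2**depth leaves of every tree.
        self.response = nn.Parameter(torch.randn(num_trees, tree_dim, 2**depth))
        # codes[d, leaf] is the branch (0 or 1) that this leaf takes at level d.
        leaves = torch.arange(2**depth)
        codes = (leaves.unsqueeze(0) >> torch.arange(depth).unsqueeze(1)) & 1
        self.register_buffer("codes_1hot", F.one_hot(codes, 2).float())  # (D, L, 2)

    def leaf_probs(self, x):
        # Soft feature choice; softmax stands in for a sparse transform here.
        weights = torch.softmax(self.feature_logits, dim=0)            # (F, T, D)
        feature_values = torch.einsum("bf,ftd->btd", x, weights)       # (B, T, D)
        logits = (feature_values - self.thresholds) * torch.exp(-self.log_temperatures)
        # Soft left/right decision per level; the two entries sum to 1.
        bin_probs = torch.sigmoid(torch.stack([-logits, logits], dim=-1))  # (B, T, D, 2)
        # Probability of following each leaf's branch at each level, multiplied over levels.
        per_level = torch.einsum("btds,dls->btdl", bin_probs, self.codes_1hot)
        return per_level.prod(dim=2)                                   # (B, T, L)

    def forward(self, x):
        # Each tree outputs the probability-weighted average of its leaf responses.
        out = torch.einsum("btl,tol->bto", self.leaf_probs(x), self.response)
        return out.reshape(x.shape[0], self.num_trees * self.tree_dim)


class SimpleDenseBlock(nn.Module):
    """Hypothetical dense stack: each layer sees the input plus all earlier tree outputs."""

    def __init__(self, in_features, num_layers=2, num_trees=4, depth=3, tree_dim=1):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(SimpleODST(in_features, num_trees, depth, tree_dim))
            in_features += num_trees * tree_dim  # the next layer also receives these outputs

    def forward(self, x):
        outputs = []
        for layer in self.layers:
            h = layer(x)
            outputs.append(h)
            x = torch.cat([x, h], dim=-1)
        return torch.cat(outputs, dim=-1)
```

Because the per-level probabilities for "left" and "right" sum to 1, the leaf probabilities of each tree also sum to 1. Every tree is therefore a soft, fully differentiable routing of each sample to its leaves.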
New Modules and Functions:
- mambular/arch_utils/data_aware_initialization.py: Added the ModuleWithInit class, a base class for PyTorch modules that perform data-aware initialization on the first batch.
- mambular/arch_utils/layer_utils/sparsemax.py: Implemented the sparsemax function and its backward pass, providing a sparse alternative to softmax.
- mambular/arch_utils/node_utils.py: Introduced the ODST class for differentiable decision tree models and the DenseBlock class for stacking layers of decision trees.
- mambular/arch_utils/numpy_utils.py: Added the check_numpy function, which ensures a tensor is converted to a NumPy array.
Refactoring:
- mambular/base_models/ft_transformer.py: Simplified the normalization layer selection by using the get_normalization_layer function.
- mambular/base_models/mlp.py: Refactored the normalization layer selection to use the get_normalization_layer function, reducing redundancy.
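The refactoring replaces normalization branching inside each base model with a single get_normalization_layer helper. The minimal sketch below shows that pattern. The signature (norm_type, num_features), the accepted strings, and the TinyMLP model are assumptions for illustration, not mambular's actual API.

```python
import torch
import torch.nn as nn


def get_normalization_layer(norm_type, num_features):
    """Hypothetical sketch: build a normalization module from a config string.

    The signature and the supported names are assumptions, not mambular's API.
    """
    if norm_type is None:
        return nn.Identity()  # "no normalization" is just a pass-through
    factories = {
        "LayerNorm": lambda: nn.LayerNorm(num_features),
        "BatchNorm": lambda: nn.BatchNorm1d(num_features),
        "GroupNorm": lambda: nn.GroupNorm(1, num_features),
    }
    if norm_type not in factories:
        raise ValueError(f"Unsupported normalization layer: {norm_type!r}")
    return factories[norm_type]()


class TinyMLP(nn.Module):
    """Hypothetical base model: the selection logic lives in one helper, not in each model."""

    def __init__(self, in_features, hidden, norm_type="LayerNorm"):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            get_normalization_layer(norm_type, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        return self.net(x)
```

With this layout, adding a new normalization option means touching one dictionary rather than every base model that supports normalization.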
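The data-aware initialization idea behind ModuleWithInit can be sketched as a base class that calls an initialize hook exactly once, on the first forward call, before running forward. This is an illustrative version, not the code in mambular/arch_utils/data_aware_initialization.py. The buffer name _is_initialized and the example subclass DataInitAffine, an ActNorm-style per-feature scale and shift, are hypothetical.

```python
import torch
import torch.nn as nn


class ModuleWithInit(nn.Module):
    """Sketch of a base class that runs `initialize` once, on the first batch seen."""

    def __init__(self):
        super().__init__()
        # A buffer (not a plain bool) so the flag is saved and restored with state_dict.
        self.register_buffer("_is_initialized", torch.tensor(False))

    def initialize(self, *args, **kwargs):
        raise NotImplementedError("Subclasses set their parameters from the first batch here")

    def __call__(self, *args, **kwargs):
        if not bool(self._is_initialized):
            # In-place writes to parameters must not be recorded by autograd.
            with torch.no_grad():
                self.initialize(*args, **kwargs)
            self._is_initialized.fill_(True)
        return super().__call__(*args, **kwargs)


class DataInitAffine(ModuleWithInit):
    """Hypothetical subclass: scale and shift chosen so the first batch comes out
    with zero mean and unit variance per feature."""

    def __init__(self, num_features):
        super().__init__()
        self.shift = nn.Parameter(torch.zeros(num_features))
        self.log_scale = nn.Parameter(torch.zeros(num_features))

    def initialize(self, x):
        self.shift.copy_(-x.mean(dim=0))
        self.log_scale.copy_(-torch.log(x.std(dim=0) + 1e-6))

    def forward(self, x):
        return (x + self.shift) * torch.exp(self.log_scale)
```

Later batches skip initialize and are transformed with the learned (and subsequently trained) parameters.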
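Sparsemax projects a score vector onto the probability simplex, so unlike softmax it can assign exact zeros. The sketch below follows the closed-form solution from Martins and Astudillo (2016) and writes the backward pass by hand as a torch.autograd.Function. The names SparsemaxFunction and sparsemax here are illustrative and may not match what mambular/arch_utils/layer_utils/sparsemax.py defines.

```python
import torch


class SparsemaxFunction(torch.autograd.Function):
    """Illustrative sparsemax over the last dimension with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, z):
        # Sort scores in descending order, z_(1) >= z_(2) >= ..., and take running sums.
        z_sorted, _ = torch.sort(z, dim=-1, descending=True)
        k = torch.arange(1, z.shape[-1] + 1, device=z.device, dtype=z.dtype)
        cumsum = z_sorted.cumsum(dim=-1)
        # 1 + k * z_(k) > z_(1) + ... + z_(k) holds for a prefix of k,
        # so counting the True entries gives the support size k(z).
        k_z = ((1 + k * z_sorted) > cumsum).sum(dim=-1, keepdim=True)
        # Threshold tau chosen so that the clipped outputs sum to 1.
        tau = (cumsum.gather(-1, k_z - 1) - 1) / k_z.to(z.dtype)
        output = torch.clamp(z - tau, min=0)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        support = (output > 0).to(grad_output.dtype)
        # The Jacobian is diag(s) - s s^T / |S| for the 0/1 support indicator s,
        # so the gradient is g minus its mean over the support, zeroed elsewhere.
        g_mean = (grad_output * support).sum(dim=-1, keepdim=True) / support.sum(dim=-1, keepdim=True)
        return support * (grad_output - g_mean)


def sparsemax(z):
    return SparsemaxFunction.apply(z)
```

For example, the scores [1.0, 2.0, 0.1, 1.8] give tau = 1.4 and the output [0, 0.6, 0, 0.4]. Two entries are exactly zero, whereas softmax would keep all four strictly positive.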
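A helper in the spirit of check_numpy can be sketched in a few lines. This assumes it only needs to handle torch.Tensor inputs (possibly on a GPU or requiring gradients) plus anything np.asarray already accepts. It is an illustration, not the function in mambular/arch_utils/numpy_utils.py.

```python
import numpy as np
import torch


def check_numpy(x):
    """Sketch: return `x` as a NumPy array, whether it is a torch.Tensor or array-like."""
    if isinstance(x, torch.Tensor):
        # detach() drops autograd history; cpu() moves GPU tensors to host memory.
        x = x.detach().cpu().numpy()
    return np.asarray(x)
```

Detaching first matters because calling .numpy() on a tensor that requires gradients raises an error.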