dreamquark-ai / tabnet

PyTorch implementation of TabNet paper : https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License
2.65k stars 488 forks source link

Training Model on MPS not working - Mac #561

Closed Altaf-LimeChat closed 1 month ago

Altaf-LimeChat commented 1 month ago

Hi, I was trying to train a model on my M1 MacBook Pro, but the model does not seem to use the MacBook's GPU at all; it trains on the CPU only. Below is the configuration I used.


import torch
from pytorch_tabnet.augmentations import ClassificationSMOTE
from pytorch_tabnet.metrics import Metric
from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.metrics import f1_score, precision_score, recall_score

# Use Apple's Metal Performance Shaders (MPS) backend when this PyTorch build
# and machine support it; otherwise fall back to the CPU instead of failing on
# the first tensor allocation.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"

torch.set_default_device(DEVICE)
torch.set_default_dtype(torch.float32)

# Small TabNet configuration used by `Classifier`.
TABNET_PARAMS_S = {
    "n_d": 10,
    "n_a": 10,
    "n_steps": 5,
    "n_shared": 3,
    "n_independent": 3,
    "optimizer_fn": torch.optim.Adam,
    "optimizer_params": dict(lr=2e-2),
    # NOTE(review): StepLR only decays the LR every `step_size` scheduler
    # steps; if the scheduler steps once per epoch, a 10-epoch run never decays.
    "scheduler_params": {"step_size": 50, "gamma": 0.9},
    "scheduler_fn": torch.optim.lr_scheduler.StepLR,
    "mask_type": "entmax",
    "device_name": DEVICE,
}

# Larger TabNet configuration used by `Classifier_large`.
TABNET_PARAMS_M = {
    "n_d": 64,
    "n_a": 64,
    "n_steps": 5,
    "n_shared": 3,
    "n_independent": 3,
    "optimizer_fn": torch.optim.AdamW,
    "optimizer_params": dict(lr=3e-3),
    "mask_type": "entmax",
    "verbose": 1,
    # No "device_name" here: callers choose the device for this configuration;
    # when none is given, TabNetClassifier's own default applies.
}

class Precision(Metric):
    """Precision of the highest-scoring class, computed with scikit-learn."""

    def __init__(self):
        # Metric identity used when this class is listed in `eval_metric`:
        # its display name, and the fact that larger values are better.
        self._maximize = True
        self._name = "precision"

    def __call__(self, y_true, y_score):
        """Return ``precision_score`` of ``y_true`` vs. the per-row arg-max of ``y_score``."""
        return precision_score(y_true, y_score.argmax(axis=1))

class Recall(Metric):
    """Recall of the highest-scoring class, computed with scikit-learn."""

    def __init__(self):
        # Metric identity used when this class is listed in `eval_metric`:
        # its display name, and the fact that larger values are better.
        self._maximize = True
        self._name = "recall"

    def __call__(self, y_true, y_score):
        """Return ``recall_score`` of ``y_true`` vs. the per-row arg-max of ``y_score``."""
        return recall_score(y_true, y_score.argmax(axis=1))

class F1Score(Metric):
    """F1 score of the highest-scoring class, computed with scikit-learn."""

    def __init__(self):
        # Metric identity used when this class is listed in `eval_metric`:
        # its display name, and the fact that larger values are better.
        self._maximize = True
        self._name = "f1_score"

    def __call__(self, y_true, y_score):
        """Return ``f1_score`` of ``y_true`` vs. the per-row arg-max of ``y_score``."""
        return f1_score(y_true, y_score.argmax(axis=1))

def Classifier(
    X_train, Y_train, X_test, Y_test, return_classifier=False, max_epochs=10
):
    """Fit a TabNetClassifier with the small configuration ``TABNET_PARAMS_S``.

    Args:
        X_train, Y_train: Training features and labels.
        X_test, Y_test: Held-out features and labels, evaluated alongside the
            training set after every epoch.
        return_classifier: Return the fitted classifier instead of its history.
        max_epochs: Number of training epochs. It is also used as the
            early-stopping patience, so training runs for all epochs.

    Returns:
        The fitted ``TabNetClassifier`` if ``return_classifier`` is true,
        otherwise ``clf.history``.
    """
    clf = TabNetClassifier(**TABNET_PARAMS_S)
    # ClassificationSMOTE has its own `device_name` setting, independent of the
    # classifier's. Give it the classifier's device so the random tensors it
    # creates live on the same device as the batches it augments (e.g. "mps").
    smote = ClassificationSMOTE(device_name=str(clf.device), p=0.25)
    clf.fit(
        X_train=X_train,
        y_train=Y_train,
        eval_set=[(X_train, Y_train), (X_test, Y_test)],
        eval_name=["train", "test"],
        eval_metric=["accuracy", Precision, Recall, F1Score, "auc"],
        # Setting max epochs and patience to the same value avoids early
        # stopping; the model with the best metrics is used at the end instead
        # of the last weights.
        max_epochs=max_epochs,
        patience=max_epochs,
        batch_size=1024,
        virtual_batch_size=128,
        num_workers=0,
        # NOTE(review): weights=1 presumably requests class-balanced sampling
        # in pytorch_tabnet — confirm against the fit() documentation.
        weights=1,
        drop_last=False,
        augmentations=smote,
        compute_importance=False,
    )

    return clf if return_classifier else clf.history

def Classifier_large(
    X_train,
    Y_train,
    X_test,
    Y_test,
    max_epochs=50,
    importance=True,
    device_name="mps",
):
    """Fit a TabNetClassifier with the larger configuration ``TABNET_PARAMS_M``.

    Args:
        X_train, Y_train: Training features and labels.
        X_test, Y_test: Held-out features and labels, evaluated alongside the
            training set after every epoch.
        max_epochs: Number of training epochs. It is also used as the
            early-stopping patience, so training runs for all epochs.
        importance: Forwarded to ``fit(compute_importance=...)``.
        device_name: Device to train on; overrides any ``device_name`` in
            ``TABNET_PARAMS_M``.

    Returns:
        The fitted ``TabNetClassifier``.
    """
    # Choose the device through the constructor rather than assigning
    # `clf.device` afterwards, so the classifier's device setting and the
    # device it actually trains on agree.
    params = {**TABNET_PARAMS_M, "device_name": device_name}
    clf = TabNetClassifier(**params)
    # The classifier's default loss is kept on purpose. Forcing
    # `torch.nn.functional.binary_cross_entropy` through the private
    # `_default_loss` breaks fit: that loss requires input and target of
    # identical shape, which per-class outputs paired with label targets do not
    # satisfy. Pass `loss_fn=` to fit() if a custom loss is really needed.

    # Keep the SMOTE augmentation on the same device as the training batches.
    smote = ClassificationSMOTE(device_name=str(clf.device), p=0.25)
    clf.fit(
        X_train=X_train,
        y_train=Y_train,
        eval_set=[(X_train, Y_train), (X_test, Y_test)],
        eval_name=["train", "test"],
        eval_metric=["accuracy", Precision, Recall, F1Score, "auc"],
        # Setting max epochs and patience to the same value avoids early
        # stopping; the model with the best metrics is used at the end instead
        # of the last weights.
        max_epochs=max_epochs,
        patience=max_epochs,
        batch_size=1024,
        virtual_batch_size=128,
        num_workers=0,
        # NOTE(review): weights=1 presumably requests class-balanced sampling
        # in pytorch_tabnet — confirm against the fit() documentation.
        weights=1,
        drop_last=False,
        augmentations=smote,
        compute_importance=importance,
    )

    return clf
Optimox commented 1 month ago

What do you see when running `torch.cuda.is_available()`? If it returns True, TabNet is using your GPU; otherwise there might be a problem with the torch installation, and you should try installing a version for which it returns True. I am not familiar with Apple's M1, so I can't say much more.

Optimox commented 1 month ago

Feel free to reopen with more information