qubvel-org / segmentation_models.pytorch

Semantic segmentation models with 500+ pretrained convolutional and transformer-based backbones.
https://smp.readthedocs.io/
MIT License

smp.utils module is deprecated #782

Closed ningmenghongcha closed 1 year ago

ningmenghongcha commented 1 year ago

I am following the cars segmentation example. To train on my custom data, I have written a train.py:

```python
import os

import torch
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils  # deprecated helpers used below

# Dataset, get_training_augmentation, get_validation_augmentation and
# get_preprocessing are defined as in the cars segmentation example notebook.

if __name__ == '__main__':

    ENCODER = 'resnet34'
    ENCODER_WEIGHTS = 'imagenet'
    CLASSES = ['object']
    ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation
    DEVICE = 'cuda'

    # create segmentation model with pretrained encoder
    model = smp.UnetPlusPlus(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=len(CLASSES),
        activation=ACTIVATION,
    )

    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

    DATA_DIR = 'data/MGD/'
    x_train_dir = os.path.join(DATA_DIR, 'train')
    y_train_dir = os.path.join(DATA_DIR, 'trainannot')

    x_valid_dir = os.path.join(DATA_DIR, 'val')
    y_valid_dir = os.path.join(DATA_DIR, 'valannot')

    train_dataset = Dataset(
        x_train_dir,
        y_train_dir,
        augmentation=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CLASSES,
    )

    valid_dataset = Dataset(
        x_valid_dir,
        y_valid_dir,
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CLASSES,
    )

    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

    loss = smp.utils.losses.DiceLoss()
    metrics = [
        smp.utils.metrics.IoU(threshold=0.5),
    ]
    optimizer = torch.optim.Adam([
        dict(params=model.parameters(), lr=0.0001),
    ])

    train_epoch = smp.utils.train.TrainEpoch(
        model,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        device=DEVICE,
        verbose=True,
    )
    valid_epoch = smp.utils.train.ValidEpoch(
        model,
        loss=loss,
        metrics=metrics,
        device=DEVICE,
        verbose=True,
    )

    # train model for 40 epochs
    max_score = 0
    os.makedirs('checkpoints', exist_ok=True)  # torch.save fails if the folder is missing

    for i in range(0, 40):

        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)
        # valid_logs = valid_epoch.run(valid_loader)

        # do something (save model, change lr, etc.)
        if max_score < train_logs['iou_score']:
            max_score = train_logs['iou_score']
            torch.save(model, 'checkpoints/best_model.pth')
            print('Model saved!')

        if i == 25:
            optimizer.param_groups[0]['lr'] = 1e-5
            print('Decrease decoder learning rate to 1e-5!')
```

However, it shows a warning that the `smp.utils` module is deprecated.

How can I use the latest API to avoid this warning? Maybe the Jupyter notebook example could be updated. Thank you for your attention.

MoriKen254 commented 1 year ago

I'm facing the same situation. I'm also curious about this.

chefkrym commented 1 year ago

@MoriKen254 @ningmenghongcha

Try this...

```python
import torch
import numpy as np
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils.metrics
from segmentation_models_pytorch import utils

ENCODER = 'mit_b4'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['1']
ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = 'cuda'

# model, preprocessing_fn, the x_/y_ train/valid directories, Dataset,
# get_training_augmentation and get_preprocessing are defined as in the cars example.

train_dataset = Dataset(
    x_train_dir,
    y_train_dir,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

valid_dataset = Dataset(
    x_valid_dir,
    y_valid_dir,
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)

# Dice/F1 score - https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
# IoU/Jaccard score - https://en.wikipedia.org/wiki/Jaccard_index

loss = utils.losses.BCELoss()
metrics = [
    utils.metrics.IoU(threshold=0.2),
    utils.metrics.Fscore(),
    utils.metrics.Recall(),
    utils.metrics.Precision(),
]

optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.0001),
])
```

MoriKen254 commented 1 year ago

@chefkrym

Thank you so much! But it still uses `smp.utils`, doesn't it?
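`smp.utils.train.TrainEpoch` and `ValidEpoch` belong to the same deprecated module, and I'm not aware of a drop-in replacement elsewhere in smp. One option is a short plain PyTorch loop. The sketch below is an assumption of how that could look, not an smp API: `run_epoch` is a hypothetical helper that trains when an optimizer is passed and only evaluates otherwise, returning the mean loss and mean metric over batches, similar in spirit to the logs dict the old helpers returned.

```python
from typing import Callable, Optional

import torch
from torch import nn
from torch.utils.data import DataLoader


def run_epoch(
    model: nn.Module,
    loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    metric_fn: Callable[[torch.Tensor, torch.Tensor], float],
    device: str,
    optimizer: Optional[torch.optim.Optimizer] = None,
) -> dict:
    """One pass over `loader`: trains if `optimizer` is given, otherwise only evaluates.

    Returns the mean loss and mean metric over batches (hypothetical helper).
    """
    training = optimizer is not None
    model.to(device)
    model.train(training)  # train() vs eval() mode, e.g. for BatchNorm/Dropout
    total_loss, total_metric, n_batches = 0.0, 0.0, 0
    with torch.set_grad_enabled(training):  # no autograd graph when evaluating
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            preds = model(images)
            loss = loss_fn(preds, masks)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            total_metric += metric_fn(preds.detach(), masks)
            n_batches += 1
    # "iou_score" matches the train_logs['iou_score'] lookup in the question's loop.
    return {"loss": total_loss / n_batches, "iou_score": total_metric / n_batches}
```

In the loop from the question, `train_logs = run_epoch(model, train_loader, loss, metric_fn, DEVICE, optimizer)` and `valid_logs = run_epoch(model, valid_loader, loss, metric_fn, DEVICE)` would stand in for the two `.run(...)` calls, where `loss` is any non-deprecated loss and `metric_fn` is any function returning a float score for a batch.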

github-actions[bot] commented 1 year ago

This issue is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 7 days.

github-actions[bot] commented 1 year ago

This issue was closed because it has been stalled for 7 days with no activity.