DeepLabCut / DeepLabCut

Official implementation of DeepLabCut: Markerless pose estimation of user-defined features with deep learning for all animals incl. humans
http://deeplabcut.org
GNU Lesser General Public License v3.0

DLC PyTorch. Augmentation inconsistencies. #2753

Closed: YankoFelipe closed this issue 4 weeks ago

YankoFelipe commented 1 month ago


Bug description

Hello DeepLabCut!

I've been observing little variation in my training results since I started adding augmentations to my project (2D / single animal / PyTorch) following the documentation at https://deeplabcut.github.io/DeepLabCut/docs/recipes/pose_cfg_file_breakdown.html, and I found some differences between the docs and the actual code in deeplabcut.pose_estimation_pytorch.data.transforms.build_transforms, such as:

if rotation is not None:
    rotation = (-rotation, rotation)
if translation is not None:
    translation = (0, translation)

It's a bit misleading that rotation is applied symmetrically but translation isn't.
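
For reference, these ranges seem to be passed straight through to the albumentations Affine transform, so with my settings (rotation: 90, translation: 50 in pytorch_config.yaml) the shift is only ever sampled in the positive direction. A minimal sketch of what I believe is built, matching the transform printed in the log below:

import albumentations as A

# Sketch (my assumption, based on the printed transform) of the Affine that
# build_transforms ends up creating for rotation: 90, translation: 50
rotation, translation = 90, 50
affine = A.Affine(
    p=0.5,
    rotate=(-rotation, rotation),  # (-90, 90): symmetric
    translate_px={"x": (0, translation), "y": (0, translation)},  # (0, 50): positive shifts only
)
print(affine)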

I understand this is still beta functionality, but I think it would be good to consider that all these differences may cause issues for people migrating from their 2.x projects.

Besides, even after modifying my configuration to match the values expected by the code (see attached log) and trying different augmentation values, the evolution of my losses (train/val) is more or less the same (increasing the number of training epochs has only led to overfitting for my project).

Operating System

SUSE Linux 15.5

DeepLabCut version

dlc version 3.0.0rc4

DeepLabCut mode

single animal

Device type

Nvidia A100

Steps To Reproduce

config.yaml (except video_sets because it's too large)

Task: mime
scorer: mime
date: Aug9
multianimalproject: false
identity:
engine: pytorch
bodyparts:
- EyeInner
- EyeOuter
- EyeTop
- EyeBottom
- PupilInner
- PupilOuter
- PupilTop
- PupilBottom
- Ear1
- Ear2
- Ear3
- Ear4
- Ear5
- Ear6
- NoseTip
- NoseTop
- NoseBottom
- PadFront
- PadTop
- PadBottom
- MouthUp
- MouthLow1
- MouthLow2
- MouthLow3
start: 0
stop: 1
numframes2pick: 5
# Plotting configuration
skeleton:
- - EyeInner
  - EyeOuter
  - EyeTop
  - EyeBottom
- - PupilInner
  - PupilOuter
  - PupilTop
  - PupilBottom
- - Ear1
  - Ear2
  - Ear3
  - Ear4
  - Ear5
  - Ear6
- - NoseTip
  - NoseTop
  - NoseBottom
- - PadFront
  - PadTop
  - PadBottom
- - MouthUp
  - MouthLow1
  - MouthLow2
  - MouthLow3
skeleton_color: black
pcutoff: 0.6
dotsize: 12
alphavalue: 0.7
colormap: rainbow
# Training,Evaluation and Analysis configuration
TrainingFraction:
- 0.85
iteration: 2
default_net_type: resnet_50
default_augmenter: default
snapshotindex: -1
detector_snapshotindex: -1
batch_size: 8
detector_batch_size: 1
# Cropping Parameters (for analysis and outlier frame detection)
cropping: false
#if cropping is true for analysis, then set the values here:
x1: 0
x2: 640
y1: 277
y2: 624
# Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
corner2move2:
- 50
- 50
move2corner: true
SuperAnimalConversionTables:

pose_cfg.yaml

all_joints:
- - 0
- - 1
- - 2
- - 3
- - 4
- - 5
- - 6
- - 7
- - 8
- - 9
- - 10
- - 11
- - 12
- - 13
- - 14
- - 15
- - 16
- - 17
- - 18
- - 19
- - 20
- - 21
- - 22
- - 23
all_joints_names:
- EyeInner
- EyeOuter
- EyeTop
- EyeBottom
- PupilInner
- PupilOuter
- PupilTop
- PupilBottom
- Ear1
- Ear2
- Ear3
- Ear4
- Ear5
- Ear6
- NoseTip
- NoseTop
- NoseBottom
- PadFront
- PadTop
- PadBottom
- MouthUp
- MouthLow1
- MouthLow2
- MouthLow3
alpha_r: 0.02
apply_prob: 0.5
augmentationprobability: 0.55
batch_size: 4
contrast:
  clahe: true
  claheratio: 0.1
  histeq: true
  histeqratio: 0.1
convolution:
  edge: false
  emboss:
    alpha:
    - 0.0
    - 1.0
    strength:
    - 0.5
    - 1.5
  embossratio: 0.1
  sharpen: false
  sharpenratio: 0.3
cropratio: 0.4
dataset: training-datasets/iteration-2/UnaugmentedDataSet_mimeAug9/mime_mime85shuffle18.mat
dataset_type: albumentations
decay_steps: 30000
display_iters: 1000
engine: pytorch
fliplr: true
global_scale: 0.8
init_weights: lib/python3.10/site-packages/deeplabcut
intermediate_supervision: false
intermediate_supervision_layer: 12
location_refinement: true
locref_huber_loss: true
locref_loss_weight: 0.05
locref_stdev: 7.2801
lr_init: 0.0005
max_input_size: 1500
metadataset: training-datasets/iteration-2/UnaugmentedDataSet_mimeAug9/Documentation_data-mime_85shuffle18.pickle
min_input_size: 64
mirror: false
multi_stage: false
multi_step:
- - 0.005
  - 10000
- - 0.02
  - 430000
- - 0.002
  - 730000
- - 0.001
  - 1030000
net_type: hrnet_w48
num_joints: 24
pairwise_huber_loss: false
pairwise_predict: false
partaffinityfield_predict: false
pos_dist_thresh: 17
project_path: mime/dlc/mime-mime-2024-08-09
rotation: 25
rotratio: 0.4
save_iters: 50000
scale_jitter_lo: 0.5
scale_jitter_up: 1.25

pytorch_config.yaml

augmentationprobability: 0.55
batch_size: 4
data:
  colormode: RGB
  inference:
    auto_padding:
      pad_height_divisor: 32
      pad_width_divisor: 32
    hflip: true
    normalize_images: true
  train:
    affine:
      p: 0.5
      rotation: 90
      scaling:
      - 1.0
      - 1.0
      translation: 50
    collate:
      max_scale: 1.0
      max_short_side: 1152
      min_scale: 0.4
      min_short_side: 128
      multiple_of: 32
      to_square: false
      type: ResizeFromDataSizeCollate
    covering: false
    gaussian_noise: 50
    hflip: true
    hist_eq: true
    motion_blur: false
    normalize_images: true
device: auto
fliplr: true
metadata:
  bodyparts:
  - EyeInner
  - EyeOuter
  - EyeTop
  - EyeBottom
  - PupilInner
  - PupilOuter
  - PupilTop
  - PupilBottom
  - Ear1
  - Ear2
  - Ear3
  - Ear4
  - Ear5
  - Ear6
  - NoseTip
  - NoseTop
  - NoseBottom
  - PadFront
  - PadTop
  - PadBottom
  - MouthUp
  - MouthLow1
  - MouthLow2
  - MouthLow3
  individuals:
  - animal
  pose_config_path: 
    mime-mime-2024-08-09/dlc-models-pytorch/iteration-2/mimeAug9-trainset85shuffle18/train/pose_cfg.yaml
  project_path: mime/dlc/mime-mime-2024-08-09
  unique_bodyparts: []
  with_identity:
method: bu
model:
  backbone:
    freeze_bn_stats: false
    freeze_bn_weights: false
    increased_channel_count: false
    interpolate_branches: false
    model_name: hrnet_w48
    type: HRNet
  backbone_output_channels: 48
  freeze_bn_stats: false
  heads:
    bodypart:
      criterion:
        heatmap:
          type: WeightedMSECriterion
          weight: 1.0
        locref:
          type: WeightedHuberCriterion
          weight: 0.05
      heatmap_config:
        channels:
        - 48
        - 24
        kernel_size:
        - 3
        strides:
        - 2
      locref_config:
        channels:
        - 48
        - 48
        kernel_size:
        - 3
        strides:
        - 2
      predictor:
        apply_sigmoid: false
        clip_scores: true
        location_refinement: true
        locref_std: 7.2801
        type: HeatmapPredictor
      target_generator:
        generate_locref: true
        heatmap_mode: KEYPOINT
        locref_std: 7.2801
        num_heatmaps: 24
        pos_dist_thresh: 17
        type: HeatmapGaussianGenerator
      type: HeatmapHead
      weight_init: normal
net_type: hrnet_w48
runner:
  eval_interval: 10
  gpus:
  key_metric: test.mAP
  key_metric_asc: true
  optimizer:
    params:
      lr: 0.0001
    type: AdamW
  scheduler:
    params:
      lr_list:
      - - 1e-05
      - - 1e-06
      milestones:
      - 160
      - 190
    type: LRListScheduler
  snapshots:
    max_snapshots: 5
    save_epochs: 25
    save_optimizer_state: false
  type: PoseTrainingRunner
train_settings:
  batch_size: 4
  dataloader_pin_memory: true
  dataloader_workers: 0
  display_iters: 1000
  epochs: 250
  seed: 306870

Relevant log output

Loading pretrained weights from Hugging Face hub (timm/resnet101.a1h_in1k)
[timm/resnet101.a1h_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
lib/python3.10/site-packages/deeplabcut/pose_estimation_pytorch/data/transforms.py:68: UserWarning: Be careful! Do not train pose models with horizontal flips if you have symmetric keypoints!
  warnings.warn(
Data Transforms:
  Training:   Compose([
  HorizontalFlip(always_apply=False, p=0.5),
  Affine(always_apply=False, p=0.5, interpolation=1, mask_interpolation=0, cval=0, mode=0, scale={'x': (1.0, 1.0), 'y': (1.0, 1.0)}, translate_percent=None, translate_px={'x': (0, 50), 'y': (0, 50)}, rotate=(-90, 90), fit_output=False, shear={'x': (0.0, 0.0), 'y': (0.0, 0.0)}, cval_mask=0, keep_ratio=True, rotate_method='largest_box'),
  Equalize(always_apply=False, p=0.5, mode='cv', by_channels=True, mask=None, mask_params=()),
  GaussNoise(always_apply=False, p=0.5, var_limit=(0, 2500), per_channel=True, mean=0),
  Normalize(always_apply=False, p=1.0, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
], p=1.0, bbox_params={'format': 'coco', 'label_fields': ['bbox_labels'], 'min_area': 0.0, 'min_visibility': 0.0, 'min_width': 0.0, 'min_height': 0.0, 'check_each_transform': True}, keypoint_params={'format': 'xy', 'label_fields': ['class_labels'], 'remove_invisible': False, 'angle_in_degrees': True, 'check_each_transform': True}, additional_targets={}, is_check_shapes=True)
  Validation: Compose([
  HorizontalFlip(always_apply=False, p=0.5),
  Normalize(always_apply=False, p=1.0, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
], p=1.0, bbox_params={'format': 'coco', 'label_fields': ['bbox_labels'], 'min_area': 0.0, 'min_visibility': 0.0, 'min_width': 0.0, 'min_height': 0.0, 'check_each_transform': True}, keypoint_params={'format': 'xy', 'label_fields': ['class_labels'], 'remove_invisible': False, 'angle_in_degrees': True, 'check_each_transform': True}, additional_targets={}, is_check_shapes=True)
Using custom collate function: {'max_scale': 1.0, 'max_short_side': 1152, 'min_scale': 0.4, 'min_short_side': 128, 'multiple_of': 32, 'to_square': False, 'type': 'ResizeFromDataSizeCollate'}
Using 285 images and 51 for testing

Starting pose model training...

Anything else?

I'd like to know if there's a way to debug and check that the augmentations are actually applied at training time in DLC PyTorch.

I've trained using the same version (3.0.0rc4) on Windows 10 (22H2) with an RTX 4090, with similar results.

I'm looking forward to your comments :)


n-poulsen commented 1 month ago

Hey @YankoFelipe, thanks for the feedback! One first general comment: in DeepLabCut 3.0, we switched to the albumentations package for image augmentation. This means that while most of the transforms available in DeepLabCut < 3 are still available, there are some minor differences. Information about the DeepLabCut 3.0 augmentations is available here: deeplabcut.github.io/DeepLabCut/docs/pytorch/pytorch_config

Image Augmentations

Horizontal flips: As you said, there's a naming change here from fliplr to hflip. The docs here are a bit out of date (I'll make sure I update them). You can add horizontal flipping as an augmentation in different ways (with your bodyparts, I'm guessing that hflip: true is the correct choice and all you need, as it doesn't look like you have any symmetric bodyparts):

# The first three here should only be used if you don't have symmetric keypoints (e.g. `leftEye`, `rightEye`) or 
# are used for an object detector

# random flip with probability 50%
hflip: true
# random flip with probability 25%
hflip: 0.25
# random flip with probability 25%
hflip:
  p: 0.25

# If you do have symmetric keypoints, you need to indicate them in the hflip configuration
#   E.g. if your bodyparts are ["nose", "rightEye", "rightEar", "leftEye", "leftEar"]
hflip:
  p: 0.25
  symmetries:
  - - 1
    - 3
  - - 2
    - 4

One element I would edit is removing the hflip: true during inference; currently that just randomly flips images when evaluating your model, which means you won't obtain the "true" performance of your model on the test set.
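
Concretely, the inference section of your pytorch_config.yaml would then only keep the padding and normalization (a sketch based on the values you posted, just with hflip removed):

data:
  inference:
    auto_padding:
      pad_height_divisor: 32
      pad_width_divisor: 32
    normalize_images: true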

Translation is not applied symmetrically: That's indeed a bug in our code, and the values should be sampled symmetrically. I'll fix this in an upcoming PR.
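
The change itself should be small, roughly something like this in build_transforms (a sketch; the actual PR may differ):

# deeplabcut/pose_estimation_pytorch/data/transforms.py (sketch of the fix)
if rotation is not None:
    rotation = (-rotation, rotation)
if translation is not None:
    # sample the shift symmetrically, the same way rotation is handled
    translation = (-translation, translation)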

Checking which augmentations are being used in training

That's also something I was curious about when developing, which is why the Albumentations transforms for training and inference are printed in your logs before training starts.

Data Transforms:
  Training:  Compose(
    [
      HorizontalFlip(always_apply=False, p=0.5),
      Affine(always_apply=False, p=0.5, interpolation=1, mask_interpolation=0, cval=0, mode=0, scale={'x': (1.0, 1.0), 'y': (1.0, 1.0)}, translate_percent=None, translate_px={'x': (0, 50), 'y': (0, 50)}, rotate=(-90, 90), fit_output=False, shear={'x': (0.0, 0.0), 'y': (0.0, 0.0)}, cval_mask=0, keep_ratio=True, rotate_method='largest_box'),
      Equalize(always_apply=False, p=0.5, mode='cv', by_channels=True, mask=None, mask_params=()),
      GaussNoise(always_apply=False, p=0.5, var_limit=(0, 2500), per_channel=True, mean=0),
      Normalize(always_apply=False, p=1.0, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
    ],
    p=1.0,
    bbox_params={'format': 'coco', 'label_fields': ['bbox_labels'], 'min_area': 0.0, 'min_visibility': 0.0, 'min_width': 0.0, 'min_height': 0.0, 'check_each_transform': True},
    keypoint_params={'format': 'xy', 'label_fields': ['class_labels'], 'remove_invisible': False, 'angle_in_degrees': True, 'check_each_transform': True},
    additional_targets={},
    is_check_shapes=True
  )
  Validation: Compose(
    [
      HorizontalFlip(always_apply=False, p=0.5),
      Normalize(always_apply=False, p=1.0, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
    ],
    ...
  )

If you're interested in seeing what the augmented data actually looks like, the following code snippet might help (which is basically what the train_network method does to load data):

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from deeplabcut.pose_estimation_pytorch.data import build_transforms, DLCLoader
from deeplabcut.pose_estimation_pytorch.task import Task

loader = DLCLoader(
    config="/Users/niels/Documents/upamathis/datasets2/test/trimice-dlc-2021-06-22/config.yaml",
    shuffle=1,
    trainset_index=0,
)

transform = build_transforms(loader.model_cfg["data"]["train"])
transform_inf = build_transforms(loader.model_cfg["data"]["inference"])

pose_task = Task(loader.model_cfg["method"])
train_dataset = loader.create_dataset(transform=transform, mode="train", task=pose_task)
valid_dataset = loader.create_dataset(transform=transform_inf, mode="test", task=pose_task)

print(f"Number of training images:  {len(train_dataset)}")
print(f"Number of validation images: {len(valid_dataset)}")

# Needed so when we plot the image, the color channels aren't normalized
denormalize = transforms.Compose(
    [
        transforms.Normalize(mean=[0, 0, 0], std=[1/0.229, 1/0.224, 1/0.225]),
        transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1, 1, 1]),
    ]
)

def plot_augmented_image(dataset, index):
    sample_train_data = dataset[index]
    train_image = sample_train_data["image"]

    # image was normalized to ImageNet means, so it needs to be un-normalized to have the correct visual appearance
    img = denormalize(torch.tensor(train_image))
    # Image is (C, H, W) and we need it to be (H, W, C) to plot it
    img = img.numpy().transpose((1, 2, 0))

    fig, ax = plt.subplots(1)
    ax.imshow(img)
    plt.show()

plot_augmented_image(train_dataset, 0)
plot_augmented_image(train_dataset, 0)

As random augmentations are applied each time an image is loaded, calling plot_augmented_image multiple times with the same image index should show a different transformation each time.
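
If you want to compare several random draws of the same frame side by side, you could also build a small grid on top of the snippet above (same setup; only the "image" key of the sample is used):

def plot_augmentation_grid(dataset, index, n_samples=4):
    """Plot several augmented draws of the same image side by side."""
    fig, axes = plt.subplots(1, n_samples, figsize=(4 * n_samples, 4))
    for ax in axes:
        # each access re-applies the random transforms to the same frame
        sample = dataset[index]
        img = denormalize(torch.tensor(sample["image"]))
        ax.imshow(img.numpy().transpose((1, 2, 0)))
        ax.axis("off")
    plt.show()

plot_augmentation_grid(train_dataset, 0)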