CAREamics / careamics

A deep-learning library for N2V and friends
https://careamics.github.io/
BSD 3-Clause "New" or "Revised" License

Passing dataloader `num_workers` param causes bad results w/o `shuffle=True` #258


melisande-c commented 4 weeks ago

When passing parameters to the dataloader in the TrainDataModule, the dataloader may be prevented from shuffling the data. A fix is to explicitly pass shuffle=True. After some further investigation, an issue should probably be raised with PyTorch Lightning.

Examples of initialising the TrainDataModule that give bad results because shuffling is somehow prevented:

config.data_config.dataloader_params = {"num_workers": 4}
train_data_module = TrainDataModule(
    config.data_config, train_data=train_path, val_data=val_path
)

or

train_data_module = create_train_datamodule(
    train_data=train_path,
    val_data=val_path,
    data_type="tiff",
    patch_size=(64, 64),
    axes="SYX",
    batch_size=16,
    dataloader_params={"num_workers": 4},
)

To specify the number of workers and still have shuffling, use dataloader_params={"num_workers": 4, "shuffle": True}.
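
For completeness, the first example above with the workaround applied:

config.data_config.dataloader_params = {"num_workers": 4, "shuffle": True}
train_data_module = TrainDataModule(
    config.data_config, train_data=train_path, val_data=val_path
)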

tlambert03 commented 3 weeks ago

Very naive question (cause I'm new here :wave:): I'm curious why you would ever get data shuffling (whether or not you manually assign .dataloader_params) without explicitly passing shuffle=True. The default for torch.utils.data.DataLoader is shuffle=None (i.e. False), and between assigning dataloader_params here:

https://github.com/CAREamics/careamics/blob/52cf7b9cc1a328f4d8514ef5f05646c2c04c669e/src/careamics/lightning/train_data_module.py#L264-L266

and creating the DataLoader here:

https://github.com/CAREamics/careamics/blob/52cf7b9cc1a328f4d8514ef5f05646c2c04c669e/src/careamics/lightning/train_data_module.py#L440-L452

I don't see any internal careamics logic that sets shuffle to True?
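
For reference, the torch default can be checked directly (a minimal sketch, plain torch only, nothing careamics-specific):

from torch.utils.data import DataLoader

# shuffle defaults to None, which torch treats as False -> SequentialSampler
loader = DataLoader(list(range(8)))
print(type(loader.sampler).__name__)  # SequentialSampler

# only an explicit shuffle=True swaps in a RandomSampler
loader = DataLoader(list(range(8)), shuffle=True)
print(type(loader.sampler).__name__)  # RandomSampler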

melisande-c commented 3 weeks ago

Hi @tlambert03, after some investigation it seems that, indeed, the train dataloader is not shuffled by default. I originally concluded that it was shuffled because passing no dataloader parameters gives good results, passing {"num_workers": 4} gives bad results, but {"num_workers": 4, "shuffle": True} gives good results again. However, I have now saved the input batches during training, and the data is not shuffled unless "shuffle": True is passed, so something strange is going on.

In the Lightning Trainer docs it says use_distributed_sampler is True by default. It mentions: "By default, it will add shuffle=True for the train sampler and shuffle=False for validation/test/predict samplers." This may be having some effect, but it is difficult to work out what is going on in the Lightning source code.
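
One way to see which sampler actually ends up on the train dataloader is to inspect it once training has started (a hedged diagnostic sketch; this callback is not part of careamics, and trainer.train_dataloader behaves as described here in Lightning 2.x):

from pytorch_lightning import Callback

class PrintSamplerCallback(Callback):
    """Print the sampler type Lightning ended up using for training."""

    def on_train_start(self, trainer, pl_module):
        # In Lightning 2.x this is the processed train dataloader
        print(type(trainer.train_dataloader.sampler))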

Experiments

Lightning component initialisation

For each of the following experiments, this is how I initialise the different Lightning components; the only difference between runs is the dataloader_params. I also turned off augmentations by passing an empty list, to ensure they weren't causing any differences.

# Imports added for completeness; exact careamics module paths may vary by version.
from pathlib import Path

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from careamics.config import create_n2v_configuration
from careamics.lightning import create_careamics_module, create_train_datamodule
from careamics.lightning.callbacks import HyperParametersCallback

num_epochs = 10
config = create_n2v_configuration(
    experiment_name="<RUN NAME>",
    data_type="tiff",
    axes="SYX",
    patch_size=(64, 64),
    batch_size=64,
    num_epochs=num_epochs,
    augmentations=[],
    dataloader_params={"num_workers": 4, "shuffle": True}
)
print(config.data_config.transforms)
print(config.data_config.dataloader_params)
config.algorithm_config.optimizer.parameters["lr"] = 1e-4

lightning_module = create_careamics_module(
    algorithm=config.algorithm_config.algorithm,
    loss=config.algorithm_config.loss,
    architecture=config.algorithm_config.model.architecture,
    optimizer_parameters=config.algorithm_config.optimizer.parameters,
)

train_data_module = create_train_datamodule(
    train_data=train_path,
    val_data=val_path,
    data_type=config.data_config.data_type,
    patch_size=config.data_config.patch_size,
    transforms=config.data_config.transforms,
    axes=config.data_config.axes,
    batch_size=config.data_config.batch_size,
    dataloader_params=config.data_config.dataloader_params
)

checkpoint_callback = ModelCheckpoint(
    dirpath=Path(__file__).parent / "checkpoints",
    filename=config.experiment_name,
    **config.training_config.checkpoint_callback.model_dump(),
)

n_batches = 5
# custom callback that saves the first n_batches input batches of each epoch;
# a sketch of it is given after this block
save_dloader_callback = SaveDataloaderOutputs(n_batches=n_batches)

trainer = Trainer(
    max_epochs=config.training_config.num_epochs,
    precision=config.training_config.precision,
    max_steps=config.training_config.max_steps,
    check_val_every_n_epoch=config.training_config.check_val_every_n_epoch,
    enable_progress_bar=config.training_config.enable_progress_bar,
    accumulate_grad_batches=config.training_config.accumulate_grad_batches,
    gradient_clip_val=config.training_config.gradient_clip_val,
    gradient_clip_algorithm=config.training_config.gradient_clip_algorithm,
    callbacks=[
        checkpoint_callback,
        HyperParametersCallback(config),
        save_dloader_callback,
    ],
    default_root_dir=Path(__file__).parent,
    logger=WandbLogger(
        name=config.experiment_name,
        save_dir=Path(__file__).parent / Path(config.experiment_name) / "logs",
    ),
)
trainer.fit(model=lightning_module, datamodule=train_data_module)
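
SaveDataloaderOutputs is not part of careamics; a minimal sketch of what it could look like, based on how it is used above (the save format and batch handling are assumptions):

import numpy as np
from pytorch_lightning import Callback

class SaveDataloaderOutputs(Callback):
    """Save the first `n_batches` input batches of every epoch so batch
    order can be compared across epochs (hypothetical implementation)."""

    def __init__(self, n_batches: int):
        self.n_batches = n_batches

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if batch_idx < self.n_batches:
            # assume the batch is a tensor, or a tuple/list whose first
            # element is the input tensor
            x = batch[0] if isinstance(batch, (tuple, list)) else batch
            np.save(
                f"batch_e{trainer.current_epoch}_b{batch_idx}.npy",
                x.detach().cpu().numpy(),
            )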

Experiment: No dataloader params

Seems to work fine; below is the prediction output.

[image: no_dloaderparams_no_transforms]

Experiment: dataloader params {"num_workers": 4}

Very bad results

[image: num_workers-4_no_transforms]

Experiment: dataloader params {"num_workers": 4, "shuffle": True}

Results look good again

[image: num_workers-4_shuffle-True_no_transforms]

Investigation

Looking at the loss and validation curves, it seems that when dataloader_params={"num_workers": 4} the model is overfitting somehow, since the training loss gets much lower than in the other runs.

[W&B charts, 29/10/2024: training and validation loss curves]

However, when I save the batches during training for the experiment with no dataloader params, they are in the same order for each epoch 🤷‍♀️. Only in the shuffle=True case are the batches shuffled.

[image: saved batch order across epochs, no_dloaderparams_no_transforms]
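
A quick way to check this from the saved batches (assuming the .npy files from the SaveDataloaderOutputs sketch above):

import numpy as np

epoch0 = [np.load(f"batch_e0_b{i}.npy") for i in range(5)]
epoch1 = [np.load(f"batch_e1_b{i}.npy") for i in range(5)]
# identical arrays at the same positions => batches were not reshuffled
print(all(np.array_equal(a, b) for a, b in zip(epoch0, epoch1)))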

The safest thing to do, of course, is to enforce shuffle=True, but it would be good to get to the bottom of this.