ContinualAI / avalanche

Avalanche: an End-to-End Library for Continual Learning based on PyTorch.
http://avalanche.continualai.org
MIT License

make_train_dataloader discards custom collate function passed as kwarg #1531

Open niniack opened 10 months ago

niniack commented 10 months ago

Describe the bug

Calling

cl_strategy.train(
    experience,
    eval_streams=[val_exp],
    num_workers=4,
    collate_fn=my_custom_collate,
)

should respect all of the keyword arguments I pass in. In this case, my_custom_collate is discarded.

To Reproduce

For debugging, I define a custom strategy to examine what is passed into the dataloader. The make_train_dataloader function is lifted as is from the 0.4.0 implementation.

(Please note that I set breakpoints with pdb)

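from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader
from avalanche.training import Naive


class CustomNaiveStrategy(Naive):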
    def make_train_dataloader(
        self,
        num_workers=0,
        shuffle=True,
        pin_memory=None,
        persistent_workers=False,
        drop_last=False,
        **kwargs
    ):
        assert self.adapted_dataset is not None

        # fmt:off
        import pdb; pdb.set_trace();
        # fmt:on

        other_dataloader_args = self._obtain_common_dataloader_parameters(
            batch_size=self.train_mb_size,
            num_workers=num_workers,
            shuffle=shuffle,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            drop_last=drop_last,
        )

        # fmt:off
        pdb.set_trace()
        # fmt:on

        if "ffcv_args" in kwargs:
            other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"]

        self.dataloader = TaskBalancedDataLoader(
            self.adapted_dataset, oversample_small_groups=True, **other_dataloader_args
        )

Expected behavior

other_dataloader_args should respect the kwargs and pass my_custom_collate along.

Screenshots

[image: bug]

In the screenshot above, `p kwargs` shows the custom collate function, but it does not show up in other_dataloader_args, which is what is passed on to TaskBalancedDataLoader.

Additional context

I cannot immediately think of why something like other_dataloader_args.update(kwargs) would be a poor idea. I would love to hear thoughts.
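One possible concern with a blanket update, sketched here under assumptions: the kwargs given to train() can carry keys that are not meant for the data loader (ffcv_args, for instance, is handled separately in the snippet above), and a loader constructor that does not accept a forwarded key would raise a TypeError. The names `ALLOWED_LOADER_KEYS` and `merge_loader_kwargs` below are hypothetical and not part of Avalanche, and the allowlist contents are only illustrative.

```python
# Hypothetical allowlist: keys considered safe to forward to the data loader.
# The exact set is an assumption made for illustration only.
ALLOWED_LOADER_KEYS = {"collate_fn", "prefetch_factor", "timeout"}


def merge_loader_kwargs(base_args, extra_kwargs, allowed=ALLOWED_LOADER_KEYS):
    """Return a copy of base_args extended with allowlisted extra kwargs only."""
    merged = dict(base_args)  # copy, so the caller's dict is left untouched
    for key, value in extra_kwargs.items():
        if key in allowed:
            merged[key] = value
    return merged


def my_custom_collate(batch):
    # Placeholder collate function standing in for the user's own.
    return batch


base = {"batch_size": 32, "num_workers": 4, "shuffle": True}
extra = {
    "collate_fn": my_custom_collate,
    "ffcv_args": {"device": "cpu"},  # handled elsewhere, so not forwarded here
    "unrelated_flag": 1,             # would break a strict constructor
}
merged = merge_loader_kwargs(base, extra)
```

With this filter, collate_fn reaches the loader arguments while unrelated keys are dropped instead of being passed through.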

niniack commented 10 months ago

This was supposedly fixed in #1089, or at least this is mentioned there:

Dataloading in strategies now checks if the dataset has a "collate_fn" function and uses that unless one is specified through kwargs (which takes precedence).

But my experience above doesn't align with that. Either way, #1089 seems relevant to the conversation.

AntonioCarta commented 10 months ago

This is definitely a bug. Can you submit a PR that properly adds collate_fn to other_dataloader_args? This should be the only needed change.

niniack commented 10 months ago

Should this fix be done by updating _obtain_common_dataloader_parameters? Or is there another Avalanche-style way of doing this? The hotfix of other_dataloader_args.update(kwargs) doesn't seem very Avalanche-y (but maybe I'm wrong!).

I will also write a test to check whether kwarg collate takes precedence over dataset collate.

Feel free to assign to me, thanks

AntonioCarta commented 10 months ago

I think updating _obtain_common_dataloader_parameters is the best way.
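A minimal sketch of that approach, not the actual Avalanche implementation: a stand-in helper takes collate_fn explicitly and falls back to the dataset's own collate_fn attribute only when none is given, which matches the precedence described in #1089. `common_dataloader_parameters` and `DummyDataset` are hypothetical names used for illustration; the real signature of _obtain_common_dataloader_parameters may differ.

```python
def common_dataloader_parameters(dataset, batch_size, collate_fn=None, **loader_args):
    """Build loader arguments; an explicit collate_fn wins over the dataset's."""
    params = dict(loader_args, batch_size=batch_size)
    if collate_fn is None:
        # Fall back to the dataset's collate function, if it defines one.
        collate_fn = getattr(dataset, "collate_fn", None)
    if collate_fn is not None:
        params["collate_fn"] = collate_fn
    return params


class DummyDataset:
    """Hypothetical dataset exposing its own collate function."""

    @staticmethod
    def collate_fn(batch):
        return ("dataset", batch)


def my_custom_collate(batch):
    return ("custom", batch)


ds = DummyDataset()
with_kwarg = common_dataloader_parameters(
    ds, 32, collate_fn=my_custom_collate, num_workers=4
)
without_kwarg = common_dataloader_parameters(ds, 32, num_workers=4)
```

Here with_kwarg carries my_custom_collate, while without_kwarg falls back to the dataset's collate function. A test for the fix could assert the same two cases against the real strategy.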

lrzpellegrini commented 7 months ago

Was this fixed in the meantime?