Nixtla / neuralforecast

Scalable and user-friendly neural 🧠 forecasting algorithms.
https://nixtlaverse.nixtla.io/neuralforecast
Apache License 2.0

[FEAT] support providing DataLoader arguments to optimize GPU usage #1186

Closed: jasminerienecker closed this 2 weeks ago

jasminerienecker commented 1 month ago

This allows adjusting the PyTorch DataLoader pin_memory and prefetch_factor arguments to optimize GPU usage.

Note: by adjusting these arguments, I can now raise GPU usage to 95%, whereas with only the num_workers argument currently exposed in the interface, GPU usage hovers around 40-60%.
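Some of these options only make sense together. As far as we know, recent PyTorch 2.x releases of `torch.utils.data.DataLoader` raise a `ValueError` when `prefetch_factor` is set with `num_workers=0`, or when `persistent_workers=True` is used without workers, while `pin_memory` applies either way (details may vary by version). The sketch below is a hypothetical, stdlib-only pre-flight check (`check_dataloader_kwargs` is not part of neuralforecast or PyTorch) that mirrors those constraints so a bad combination fails before training starts.

```python
def check_dataloader_kwargs(dataloader_kwargs=None):
    """Hypothetical pre-flight check mirroring constraints that recent
    PyTorch DataLoader versions enforce at construction time."""
    kwargs = dict(dataloader_kwargs or {})  # copy; None means "no extra options"
    num_workers = kwargs.get("num_workers", 0)
    if num_workers < 0:
        raise ValueError("num_workers must be non-negative")
    # prefetch_factor only applies when worker processes load batches ahead of time
    if num_workers == 0 and kwargs.get("prefetch_factor") is not None:
        raise ValueError("prefetch_factor requires num_workers > 0")
    # persistent_workers keeps worker processes alive, so workers must exist
    if num_workers == 0 and kwargs.get("persistent_workers", False):
        raise ValueError("persistent_workers requires num_workers > 0")
    return kwargs


# Example: the kind of settings described above for keeping the GPU busy.
settings = check_dataloader_kwargs(
    {"num_workers": 4, "pin_memory": True, "prefetch_factor": 4}
)
```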

CLAassistant commented 1 month ago

CLA assistant check
All committers have signed the CLA.

jmoralez commented 1 month ago

I'd prefer to introduce a single argument like dataloader_kwargs that gets passed to the DataLoader constructors, and to deprecate the num_workers_loader argument instead. That way we could also support persistent_workers, for example.
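A minimal sketch of the single-argument design, assuming a hypothetical `build_loader_kwargs` helper (not neuralforecast's actual code): the model's own loader settings form the base, and whatever the user puts in `dataloader_kwargs` (`num_workers`, `pin_memory`, `prefetch_factor`, `persistent_workers`, ...) is merged on top before being handed to the `DataLoader` constructor. Defaulting to `None` instead of `{}` avoids a shared mutable default argument.

```python
def build_loader_kwargs(batch_size, shuffle, dataloader_kwargs=None):
    """Hypothetical helper: combine the model's own loader settings with
    user-supplied DataLoader keyword arguments."""
    extra = dict(dataloader_kwargs or {})  # copy so the caller's dict is untouched
    # Keys the model controls itself; reject them rather than silently overriding.
    clash = {"batch_size", "shuffle"}.intersection(extra)
    if clash:
        raise ValueError(f"set {sorted(clash)} through the model, not dataloader_kwargs")
    return {"batch_size": batch_size, "shuffle": shuffle, **extra}


user_kwargs = {"num_workers": 4, "persistent_workers": True}
loader_kwargs = build_loader_kwargs(32, True, user_kwargs)
# e.g. torch.utils.data.DataLoader(dataset, **loader_kwargs)
```

Rejecting `batch_size` and `shuffle` is one design choice among several; letting the model's values take precedence would also work, but then a user's setting would be ignored without notice.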

jasminerienecker commented 1 month ago

I've updated this: the PR now essentially replaces the num_workers arguments with a dataloader_kwargs dictionary (defaulting to None).

jmoralez commented 1 month ago

Sorry, by deprecating I meant keeping the argument and then doing something like:

import warnings

if self.num_workers_loader != 0:  # value is not at its default
    warnings.warn(
        "The `num_workers_loader` argument is deprecated and will be removed in a future version. "
        "Please provide num_workers through `dataloader_kwargs`, e.g. "
        # doubled braces emit literal { and } inside the f-string
        f"`dataloader_kwargs={{'num_workers': {self.num_workers_loader}}}`",
        category=FutureWarning,
    )
dataloader_kwargs['num_workers'] = self.num_workers_loader

jasminerienecker commented 1 month ago

@jmoralez Ah yes, I see. I've put the argument back and added the deprecation warnings to the base model classes.

jmoralez commented 4 weeks ago

Thanks! Sorry, I messed up the suggestion: the dataloader_kwargs['num_workers'] = self.num_workers_loader assignment should go inside the if block. Otherwise we'll override num_workers every time, instead of only when the user sets num_workers_loader.
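With that fix, the shim might look like the sketch below. It is a standalone, hypothetical function (`resolve_dataloader_kwargs` is not neuralforecast's actual code) that copies the user's dict, emits a `FutureWarning` only when `num_workers_loader` differs from its default of 0, and sets `num_workers` inside that branch. Note the doubled braces (`{{`, `}}`), which an f-string needs to emit literal `{` and `}`.

```python
import warnings


def resolve_dataloader_kwargs(num_workers_loader=0, dataloader_kwargs=None):
    """Hypothetical deprecation shim: honour the old num_workers_loader
    argument only when the user changed it from its default of 0."""
    kwargs = dict(dataloader_kwargs or {})  # copy; never mutate the user's dict
    if num_workers_loader != 0:  # value is not at its default
        warnings.warn(
            "The `num_workers_loader` argument is deprecated and will be removed "
            "in a future version. Please provide num_workers through "
            f"`dataloader_kwargs`, e.g. `dataloader_kwargs={{'num_workers': {num_workers_loader}}}`",
            category=FutureWarning,
        )
        # Inside the if block, so a num_workers given via dataloader_kwargs
        # is not overwritten by the default value of 0.
        kwargs["num_workers"] = num_workers_loader
    return kwargs
```

If a user sets both, the deprecated argument wins here, matching the suggestion above; raising an error on conflicting values would be a stricter alternative.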

jasminerienecker commented 4 weeks ago

@jmoralez makes sense - that should be updated in both cases now!

jmoralez commented 4 weeks ago

Thanks! I'm very sorry, I just realized from your changes to BaseRecurrent that we have a similar argument for predict called data_module_kwargs. I don't think you can pass any useful arguments to the data module through it, though: only batch_size, drop_last and shuffle_train, none of which are used during predict.

@cchallu what's the purpose of the data_module_kwargs argument?

jmoralez commented 4 weeks ago

@jasminerienecker in the meantime, can you please revert the changes to the predict method of BaseRecurrent? We can limit this PR to adding arguments to the training step.

jasminerienecker commented 3 weeks ago

@jmoralez all good - that's now been updated