jasminerienecker closed this 2 weeks ago
I'd prefer to introduce a single argument like `dataloader_kwargs` that gets passed to the dataloaders' constructors, and deprecate the `num_workers_loader` argument instead. That way we could also support `persistent_workers`, for example.
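A minimal sketch of the proposed design, assuming a data module that builds the training loader itself. The class `TrainingDataModule` and its `loader_cls` parameter are hypothetical and not part of the library; `loader_cls` only exists so the sketch can run without PyTorch, and it falls back to `torch.utils.data.DataLoader` otherwise.

```python
from typing import Any, Callable, Dict, Optional


class TrainingDataModule:
    """Hypothetical sketch: forwards user-supplied kwargs to a DataLoader.

    Names here are illustrative only; they are not taken from the library.
    """

    def __init__(
        self,
        dataset: Any,
        batch_size: int = 32,
        dataloader_kwargs: Optional[Dict[str, Any]] = None,
        loader_cls: Optional[Callable[..., Any]] = None,
    ) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        # Copy so later mutation does not leak back into the caller's dict;
        # the default None becomes an empty dict.
        self.dataloader_kwargs = dict(dataloader_kwargs or {})
        self.loader_cls = loader_cls

    def train_dataloader(self) -> Any:
        loader_cls = self.loader_cls
        if loader_cls is None:
            # Import lazily so the sketch can be exercised without PyTorch.
            from torch.utils.data import DataLoader
            loader_cls = DataLoader
        # Every option (num_workers, persistent_workers, pin_memory, ...)
        # flows through one dict instead of one constructor argument each.
        return loader_cls(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            **self.dataloader_kwargs,
        )
```

With PyTorch installed, `TrainingDataModule(ds, dataloader_kwargs={"num_workers": 4, "persistent_workers": True}).train_dataloader()` would build a `DataLoader` with those options, and supporting a new loader option requires no new constructor argument.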
I've updated this: the PR now essentially replaces the `num_workers` variables directly with a `dataloader_kwargs` dictionary (default `None`).
Sorry, by deprecating I meant keeping the argument and then doing something like:
```python
import warnings

if self.num_workers_loader != 0:  # value is not at its default
    warnings.warn(
        "The `num_workers_loader` argument is deprecated and will be removed in a future version. "
        "Please provide num_workers through `dataloader_kwargs`, e.g. "
        f"`dataloader_kwargs={{'num_workers': {self.num_workers_loader}}}`",
        category=FutureWarning,
    )
dataloader_kwargs['num_workers'] = self.num_workers_loader
```
@jmoralez Ah yes, I see. I've put that back in and added the deprecation warnings to the base model classes.
Thanks! Sorry, I messed up the suggestion. We should do the `dataloader_kwargs['num_workers'] = self.num_workers_loader` assignment inside the `if` block; otherwise we'll override `num_workers` every time, instead of only when the user sets the deprecated argument.
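The corrected logic can be sketched as a standalone function, assuming the deprecated argument defaults to `0` and `dataloader_kwargs` defaults to `None` as described in this thread. The name `resolve_dataloader_kwargs` is hypothetical; the PR places this logic in the base model classes instead. Note the doubled braces in the f-string, which are needed to emit literal `{` and `}`.

```python
import warnings
from typing import Any, Dict, Optional


def resolve_dataloader_kwargs(
    num_workers_loader: int = 0,
    dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: fold the deprecated argument into the new dict."""
    # Start from a copy; the default None becomes an empty dict.
    kwargs = dict(dataloader_kwargs or {})
    if num_workers_loader != 0:  # user explicitly set the deprecated argument
        warnings.warn(
            "The `num_workers_loader` argument is deprecated and will be removed "
            "in a future version. Please provide num_workers through "
            "`dataloader_kwargs`, e.g. "
            f"`dataloader_kwargs={{'num_workers': {num_workers_loader}}}`",
            category=FutureWarning,
        )
        # Override only inside the branch, so a num_workers value passed via
        # dataloader_kwargs survives when the deprecated argument is untouched.
        kwargs["num_workers"] = num_workers_loader
    return kwargs
```

With the default `num_workers_loader=0`, `resolve_dataloader_kwargs(0, {"num_workers": 2})` keeps `num_workers=2`; only an explicit non-zero value triggers the `FutureWarning` and the override.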
@jmoralez Makes sense. That should now be updated in both places!
Thanks! I'm very sorry, I just realized from your changes to `BaseRecurrent` that we have a similar argument for `predict` called `data_module_kwargs`. However, I don't think you can pass useful arguments to the data module through it: only `batch_size`, `drop_last` and `shuffle_train`, none of which are used during `predict`.
@cchallu what's the purpose of the `data_module_kwargs` argument?
@jasminerienecker In the meantime, can you please revert the changes to the `predict` method of `BaseRecurrent`? We can limit this PR to adding arguments to the training step.
@jmoralez all good - that's now been updated
This PR allows adjusting the PyTorch `DataLoader` arguments `pin_memory` and `prefetch_factor` to optimize GPU usage.
Note: by adjusting these arguments I can now increase GPU utilization to 95%, whereas with only the `num_workers` argument currently exposed in the interface, GPU utilization hovers around 40-60%.
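Because all of these options now travel in one dictionary, a small pre-flight check can catch invalid combinations before training starts. The sketch below is hypothetical (`check_dataloader_kwargs` is not a library function). It mirrors the rule that recent PyTorch versions reject `persistent_workers=True` or a non-`None` `prefetch_factor` when `num_workers` is 0, since both only apply to worker subprocesses, while `pin_memory` has no such restriction. The example values are illustrative, not the settings behind the utilization figures in this PR.

```python
from typing import Any, Dict


def check_dataloader_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical check mirroring PyTorch DataLoader worker constraints."""
    num_workers = kwargs.get("num_workers", 0)
    if num_workers == 0:
        # Both options configure worker subprocesses, which do not exist here.
        if kwargs.get("persistent_workers", False):
            raise ValueError("persistent_workers=True needs num_workers > 0")
        if kwargs.get("prefetch_factor") is not None:
            raise ValueError("prefetch_factor needs num_workers > 0")
    return kwargs


# Illustrative GPU-oriented settings (values are assumptions, not tuned).
gpu_friendly = check_dataloader_kwargs(
    {
        "num_workers": 4,            # load batches in 4 worker processes
        "pin_memory": True,          # page-locked host memory for faster host-to-GPU copies
        "prefetch_factor": 4,        # batches loaded in advance by each worker
        "persistent_workers": True,  # keep workers alive between epochs
    }
)
```

Assuming a model constructor that accepts `dataloader_kwargs` as proposed here, one would then pass `dataloader_kwargs=gpu_friendly`.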