catalyst-team / catalyst

Accelerated deep learning R&D
https://catalyst-team.com
Apache License 2.0

`DataParallelEngine.prepare_model` missing `device_placement` kwarg. #1448

Closed. podestplatz closed this issue 1 year ago

podestplatz commented 1 year ago

🐛 Bug Report

DataParallelEngine inherits from accelerate.Accelerator, which defines prepare_model with a device_placement keyword argument. DataParallelEngine overrides this method without that argument. As a result, Accelerator._prepare_one() raises a TypeError: at line 741 of accelerate/accelerator.py it calls self.prepare_model with device_placement=device_placement, and that call is dispatched to DataParallelEngine.prepare_model, which does not accept device_placement. The override would go on to call its super implementation if it did not fail on the unexpected kwarg first.
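The dispatch problem described above can be reproduced without Catalyst or Accelerate installed. The following is only a sketch: `BaseAccelerator`, `BrokenEngine` and `try_prepare` are hypothetical stand-ins written for illustration, not the real library classes, and the method bodies are placeholders.

```python
class BaseAccelerator:
    """Stand-in for the parent class (hypothetical, not the real accelerate code)."""

    def prepare_model(self, model, device_placement=None):
        return model

    def _prepare_one(self, obj, device_placement=None):
        # The parent calls through `self`, so a subclass override is what runs.
        return self.prepare_model(obj, device_placement=device_placement)


class BrokenEngine(BaseAccelerator):
    """Stand-in for an override whose signature drops `device_placement`."""

    def prepare_model(self, model):
        return super().prepare_model(model)


def try_prepare(engine):
    """Return the TypeError message raised by `_prepare_one`, or None on success."""
    try:
        engine._prepare_one(object(), device_placement=True)
    except TypeError as err:
        return str(err)
    return None


base_error = try_prepare(BaseAccelerator())
broken_error = try_prepare(BrokenEngine())
print("parent class:", base_error)
print("override without kwarg:", broken_error)
```

The parent class alone handles the call, while the override fails before it ever reaches `super().prepare_model()`, which mirrors the last frame of the stack trace in this report.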

How To Reproduce

Steps to reproduce the behavior:

  1. Run the minimal code example below.

The stack trace I get on my machine is:

~/miniconda3/envs/detekt/lib/python3.9/site-packages/accelerate/accelerator.py:224: FutureWarning: `fp16=True` is deprecated and will be removed in version 0.15.0 of 🤗 Accelerate. Use `mixed_precision='fp16'` instead.
  warnings.warn(
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ ~/.config/JetBrains/PyCharm2023.1/scratches/scratch_4.py:21 in               │
│ <module>                                                                     │
│                                                                              │
│   18                                                                         │
│   19 # model training                                                        │
│   20 runner = dl.SupervisedRunner()                                          │
│ ❱ 21 runner.train(                                                           │
│   22 │   model=model,                                                        │
│   23 │   engine=DataParallelEngine(fp16=True),                               │
│   24 │   criterion=criterion,                                                │
│                                                                              │
│ ~/miniconda3/envs/detekt/lib/python3.9/site-packages/catalyst/ru             │
│ nners/runner.py:377 in train                                                 │
│                                                                              │
│   374 │   │   self._profile = profile                                        │
│   375 │   │   self._load_best_on_end = load_best_on_end                      │
│   376 │   │   # run                                                          │
│ ❱ 377 │   │   self.run()                                                     │
│   378 │                                                                      │
│   379 │   @torch.no_grad()                                                   │
│   380 │   def predict_batch(self, batch: Mapping[str, Any], **kwargs) -> Map │
│                                                                              │
│ ~/miniconda3/envs/detekt/lib/python3.9/site-packages/catalyst/co             │
│ re/runner.py:422 in run                                                      │
│                                                                              │
│   419 │   │   │   self._run()                                                │
│   420 │   │   except (Exception, KeyboardInterrupt) as ex:                   │
│   421 │   │   │   self.exception = ex                                        │
│ ❱ 422 │   │   │   self._run_event("on_exception")                            │
│   423 │   │   return self                                                    │
│   424                                                                        │
│   425                                                                        │
│                                                                              │
│ ~/miniconda3/envs/detekt/lib/python3.9/site-packages/catalyst/co             │
│ re/runner.py:365 in _run_event                                               │
│                                                                              │
│   362 │   │   for callback in self.callbacks.values():                       │
│   363 │   │   │   getattr(callback, event)(self)                             │
│   364 │   │   if is_str_intersections(event, ("_end", "_exception")):        │
│ ❱ 365 │   │   │   getattr(self, event)(self)                                 │
│   366 │                                                                      │
│   367 │   @abstractmethod                                                    │
│   368 │   def handle_batch(self, batch: Mapping[str, Any]) -> None:          │
│                                                                              │
│ ~/miniconda3/envs/detekt/lib/python3.9/site-packages/catalyst/co             │
│ re/runner.py:357 in on_exception                                             │
│                                                                              │
│   354 │                                                                      │
│   355 │   def on_exception(self, runner: "IRunner"):                         │
│   356 │   │   """Event handler."""                                           │
│ ❱ 357 │   │   raise self.exception                                           │
│   358 │                                                                      │
│   359 │   def _run_event(self, event: str) -> None:                          │
│   360 │   │   if is_str_intersections(event, ("_start",)):                   │
│                                                                              │
│ ~/miniconda3/envs/detekt/lib/python3.9/site-packages/catalyst/co             │
│ re/runner.py:419 in run                                                      │
│                                                                              │
│   416 │   │   │   self, `IRunner` instance after the experiment              │
│   417 │   │   """                                                            │
│   418 │   │   try:                                                           │
│ ❱ 419 │   │   │   self._run()                                                │
│   420 │   │   except (Exception, KeyboardInterrupt) as ex:                   │
│   421 │   │   │   self.exception = ex                                        │
│   422 │   │   │   self._run_event("on_exception")                            │
│                                                                              │
│ ~/miniconda3/envs/detekt/lib/python3.9/site-packages/catalyst/co             │
│ re/runner.py:410 in _run                                                     │
│                                                                              │
│   407 │                                                                      │
│   408 │   def _run(self) -> None:                                            │
│   409 │   │   self.engine = self.get_engine()                                │
│ ❱ 410 │   │   self.engine.spawn(self._run_local)                             │
│   411 │                                                                      │
│   412 │   def run(self) -> "IRunner":                                        │
│   413 │   │   """Runs the experiment.                                        │
│                                                                              │
│ ~/miniconda3/envs/detekt/lib/python3.9/site-packages/catalyst/co             │
│ re/engine.py:59 in spawn                                                     │
│                                                                              │
│   56 │   │   Returns:                                                        │
│   57 │   │   │   wrapped function (if needed).                               │
│   58 │   │   """                                                             │
│ ❱ 59 │   │   return fn(*args, **kwargs)                                      │
│   60 │                                                                       │
│   61 │   def setup(self, local_rank: int, world_size: int):                  │
│   62 │   │   """Initialize DDP variables and processes if required.          │
│                                                                              │
│ ~/miniconda3/envs/detekt/lib/python3.9/site-packages/catalyst/co             │
│ re/runner.py:404 in _run_local                                               │
│                                                                              │
│   401 │                                                                      │
│   402 │   def _run_local(self, local_rank: int = -1, world_size: int = 1) -> │
│   403 │   │   self._local_rank, self._world_size = local_rank, world_size    │
│ ❱ 404 │   │   self._run_event("on_experiment_start")                         │
│   405 │   │   self._run_experiment()                                         │
│   406 │   │   self._run_event("on_experiment_end")                           │
│   407                                                                        │
│                                                                              │
│ ~/miniconda3/envs/detekt/lib/python3.9/site-packages/catalyst/co             │
│ re/runner.py:361 in _run_event                                               │
│                                                                              │
│   358 │                                                                      │
│   359 │   def _run_event(self, event: str) -> None:                          │
│   360 │   │   if is_str_intersections(event, ("_start",)):                   │
│ ❱ 361 │   │   │   getattr(self, event)(self)                                 │
│   362 │   │   for callback in self.callbacks.values():                       │
│   363 │   │   │   getattr(callback, event)(self)                             │
│   364 │   │   if is_str_intersections(event, ("_end", "_exception")):        │
│                                                                              │
│ ~/miniconda3/envs/detekt/lib/python3.9/site-packages/catalyst/co             │
│ re/runner.py:278 in on_experiment_start                                      │
│                                                                              │
│   275 │   │   │   self.log_hparams(hparams=self.hparams)                     │
│   276 │   │   with self.engine.local_main_process_first():                   │
│   277 │   │   │   self._setup_loaders()                                      │
│ ❱ 278 │   │   self._setup_components()                                       │
│   279 │   │   self._setup_callbacks()                                        │
│   280 │                                                                      │
│   281 │   def on_epoch_start(self, runner: "IRunner"):                       │
│                                                                              │
│ ~/miniconda3/envs/detekt/lib/python3.9/site-packages/catalyst/co             │
│ re/runner.py:241 in _setup_components                                        │
│                                                                              │
│   238 │   │   self.scheduler = self._setup_scheduler(optimizer=self.optimize │
│   239 │   │                                                                  │
│   240 │   │   if isinstance(self.model, torch.nn.Module):                    │
│ ❱ 241 │   │   │   self.model = self.engine.prepare(self.model)               │
│   242 │   │   elif isinstance(self.model, dict):                             │
│   243 │   │   │   self.model = {k: self.engine.prepare(v) for k, v in self.m │
│   244 │   │   else:                                                          │
│                                                                              │
│ ~/miniconda3/envs/detekt/lib/python3.9/site-packages/accelerate/             │
│ accelerator.py:876 in prepare                                                │
│                                                                              │
│    873 │   │   elif self.distributed_type == DistributedType.MEGATRON_LM:    │
│    874 │   │   │   result = self._prepare_megatron_lm(*args)                 │
│    875 │   │   else:                                                         │
│ ❱  876 │   │   │   result = tuple(                                           │
│    877 │   │   │   │   self._prepare_one(obj, first_pass=True, device_placem │
│    878 │   │   │   )                                                         │
│    879 │   │   │   result = tuple(self._prepare_one(obj, device_placement=d) │
│                                                                              │
│ ~/miniconda3/envs/detekt/lib/python3.9/site-packages/accelerate/             │
│ accelerator.py:877 in <genexpr>                                              │
│                                                                              │
│    874 │   │   │   result = self._prepare_megatron_lm(*args)                 │
│    875 │   │   else:                                                         │
│    876 │   │   │   result = tuple(                                           │
│ ❱  877 │   │   │   │   self._prepare_one(obj, first_pass=True, device_placem │
│    878 │   │   │   )                                                         │
│    879 │   │   │   result = tuple(self._prepare_one(obj, device_placement=d) │
│    880                                                                       │
│                                                                              │
│ ~/miniconda3/envs/detekt/lib/python3.9/site-packages/accelerate/             │
│ accelerator.py:741 in _prepare_one                                           │
│                                                                              │
│    738 │   │   │   if isinstance(obj, torch.utils.data.DataLoader):          │
│    739 │   │   │   │   return self.prepare_data_loader(obj, device_placement │
│    740 │   │   │   elif isinstance(obj, torch.nn.Module):                    │
│ ❱  741 │   │   │   │   return self.prepare_model(obj, device_placement=devic │
│    742 │   │   │   elif isinstance(obj, torch.optim.Optimizer):              │
│    743 │   │   │   │   optimizer = self.prepare_optimizer(obj, device_placem │
│    744 │   │   │   │   return optimizer                                      │
╰──────────────────────────────────────────────────────────────────────────────╯
TypeError: prepare_model() got an unexpected keyword argument 'device_placement'

Code sample

import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst import dl
from catalyst.engines.torch import DataParallelEngine

# data
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# model, criterion, optimizer, scheduler
model = torch.nn.Linear(num_features, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])

# model training
runner = dl.SupervisedRunner()
runner.train(
    model=model,
    engine=DataParallelEngine(fp16=True),
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    logdir="./logdir",
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    num_epochs=8,
    verbose=True,
)

Expected behavior

I'd expect the model to train without errors, i.e. DataParallelEngine.prepare_model should forward all the kwargs it receives to super().prepare_model().
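A minimal sketch of the forwarding behavior requested here, again using hypothetical stand-in classes (`BaseAccelerator`, `ForwardingEngine`) rather than the actual Catalyst or Accelerate code. The tuple returned by the parent exists only to make the forwarded value visible.

```python
class BaseAccelerator:
    """Stand-in for the parent class (hypothetical)."""

    def prepare_model(self, model, device_placement=None):
        # Return what was received so the forwarded kwarg can be inspected.
        return ("base", model, device_placement)

    def _prepare_one(self, obj, device_placement=None):
        return self.prepare_model(obj, device_placement=device_placement)


class ForwardingEngine(BaseAccelerator):
    """Override that accepts and forwards any keyword arguments."""

    def prepare_model(self, model, **kwargs):
        # Engine-specific wrapping of `model` would happen here (assumption).
        return super().prepare_model(model, **kwargs)


result = ForwardingEngine()._prepare_one("model", device_placement=True)
print(result)
```

Accepting `**kwargs` rather than naming `device_placement` explicitly keeps the override tolerant of keyword arguments that the parent class may add later, at the cost of a less self-documenting signature.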

Environment

The output of the collect_env.py script is the following:

Catalyst version: 22.04
PyTorch version: 1.13.1+cu117
Is debug build: No
CUDA used to build PyTorch: 11.7
TensorFlow version: N/A
TensorBoard version: 2.11.0

OS: Ubuntu 22.04.2 LTS
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
CMake version: Could not collect

Python version: 3.9
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070 Laptop GPU
Nvidia driver version: 525.85.12
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] catalyst==22.4
[pip3] numpy==1.23.5
[pip3] pytorch-ranger==0.1.1
[pip3] tensorboard==2.11.0
[pip3] tensorboard-data-server==0.6.1
[pip3] tensorboard-plugin-wit==1.8.1
[pip3] tensorboardX==2.2
[pip3] torch==1.13.1
[pip3] torch-optimizer==0.3.0
[pip3] torchinfo==1.7.2
[pip3] torchvision==0.14.1
[conda] catalyst                  22.4                     pypi_0    pypi
[conda] cudatoolkit               11.7.0              hd8887f6_10    nvidia
[conda] numpy                     1.23.5           py39h3d75532_0    conda-forge
[conda] pytorch-ranger            0.1.1                    pypi_0    pypi
[conda] tensorboard               2.11.0                   pypi_0    pypi
[conda] tensorboard-data-server   0.6.1                    pypi_0    pypi
[conda] tensorboard-plugin-wit    1.8.1                    pypi_0    pypi
[conda] tensorboardx              2.2                      pypi_0    pypi
[conda] torch                     1.13.1                   pypi_0    pypi
[conda] torch-optimizer           0.3.0                    pypi_0    pypi
[conda] torchinfo                 1.7.2                    pypi_0    pypi
[conda] torchvision               0.14.1                   pypi_0    pypi

github-actions[bot] commented 1 year ago

Hi! Thank you for your contribution! Please re-check all issue template checklists; unfilled issues will be closed automatically. And do not forget to join our Slack for collaboration.

bagxi commented 1 year ago

Please use accelerate==0.5.1
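For anyone applying this suggestion, one way to confirm which version is actually installed in the active environment is sketched below. It uses only the standard library; the helper names `installed_version` and `matches_pin` are made up for this example.

```python
from importlib import metadata

REQUIRED = "0.5.1"  # the version suggested in the comment above


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def matches_pin(package, required):
    """True only if the package is installed at exactly the required version."""
    return installed_version(package) == required


found = installed_version("accelerate")
if found is None:
    print("accelerate is not installed")
elif matches_pin("accelerate", REQUIRED):
    print("accelerate matches the suggested pin:", found)
else:
    print("accelerate", found, "is installed; suggested version is", REQUIRED)
```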