microsoft / nni

An open source AutoML toolkit that automates the machine learning lifecycle, including feature engineering, neural architecture search, model compression and hyper-parameter tuning.
https://nni.readthedocs.io
MIT License

Proxylessnas ModelHooks.on_train_batch_start() issue #5055

Closed AL3708 closed 2 years ago

AL3708 commented 2 years ago

I have an example model from the docs:

import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
import torch  # needed for torch.flatten in forward()
import torch.nn.functional as F

class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=3, groups=in_ch)
        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

@model_wrapper  # this decorator should be put on the outermost module
class Net(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.LayerChoice([
            nn.Conv2d(32, 64, 3, 1),
            DepthwiseSeparableConv(32, 64)
        ])
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 1000)
        self.fc2 = nn.Linear(1000, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(self.conv2(x), 2)
        x = torch.flatten(self.dropout1(x), 1)
        x = self.fc2(self.dropout2(F.relu(self.fc1(x))))
        return F.log_softmax(x, dim=1)

If I run the experiment with the Random strategy, it works fine:

import nni
import nni.retiarii.strategy as strategy
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from torchvision import transforms
from torchvision.datasets import MNIST

transform = nni.trace(transforms.Compose)([nni.trace(transforms.ToTensor)(), nni.trace(transforms.Normalize)((0.1307,), (0.3081,))])
train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
test_dataset = nni.trace(MNIST)('data/mnist', train=False, transform=transform)
evaluator = pl.Classification(train_dataloaders=pl.DataLoader(train_dataset, batch_size=100),
                              val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
                              accelerator='gpu',max_epochs=10)

model = Net()
search_strategy = strategy.Random()
# search_strategy = strategy.Proxyless()
exp = RetiariiExperiment(model, evaluator, [], search_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnist_search'
# exp_config.execution_engine = 'oneshot'
exp_config.max_trial_number = 1 
exp_config.trial_concurrency = 1 
exp_config.trial_gpu_number = 1
exp_config.training_service.use_active_gpu = True
exp.run(exp_config, 8081)

but when I switch to Proxyless, an error occurs:

# search_strategy = strategy.Random()
search_strategy = strategy.Proxyless()
exp = RetiariiExperiment(model, evaluator, [], search_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnist_search'
exp_config.execution_engine = 'oneshot'
exp_config.max_trial_number = 1 
exp_config.trial_concurrency = 1 
exp_config.trial_gpu_number = 1
exp_config.training_service.use_active_gpu = True
exp.run(exp_config, 8081)

Error:

Tensorflow is not installed.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
C:\Users\...\lib\site-packages\pytorch_lightning\trainer\configuration_validator.py:228: LightningDeprecationWarning: The `LightningModule.on_epoch_start` hook was deprecated in v1.6 and will be removed in v1.8. Please use `LightningModule.on_<train/validation/test>_epoch_start` instead.
  rank_zero_deprecation(
C:\Users\...\lib\site-packages\pytorch_lightning\trainer\configuration_validator.py:228: LightningDeprecationWarning: The `LightningModule.on_epoch_end` hook was deprecated in v1.6 and will be removed in v1.8. Please use `LightningModule.on_<train/validation/test>_epoch_end` instead.
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  | Name  | Type                  | Params
------------------------------------------------
0 | model | _ClassificationModule | 9.2 M 
------------------------------------------------
9.2 M     Trainable params
0         Non-trainable params
9.2 M     Total params
36.993    Total estimated model params size (MB)
C:\Users\...\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:219: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 6 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Epoch 0:   0%|          | 0/600 [00:00<?, ?it/s] Traceback (most recent call last):
  File "C:\Users\...\lib\site-packages\IPython\core\interactiveshell.py", line 3398, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-40724acf6027>", line 1, in <cell line: 1>
    runfile('C:/Users/.../scripts/proxylessnas.py', wdir='C:/Users/.../scripts')
  File "C:\Program Files\JetBrains\PyCharm 2022.1.3\plugins\python\helpers\pydev\_pydev_bundle\pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm 2022.1.3\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "C:/Users/.../scripts/proxylessnas.py", line 184, in <module>
    exp.run(exp_config, 8081)
  File "C:\Users\...\lib\site-packages\nni\retiarii\experiment\pytorch.py", line 289, in run
    self.strategy.run(base_model_ir, self.applied_mutators)
  File "C:\Users\...\lib\site-packages\nni\retiarii\oneshot\pytorch\strategy.py", line 76, in run
    evaluator.trainer.fit(self.model, train_loader, val_loader)
  File "C:\Users\...\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 700, in fit
    self._call_and_handle_interrupt(
  File "C:\Users\...\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 654, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "C:\Users\...\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 741, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "C:\Users\...\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1166, in _run
    results = self._run_stage()
  File "C:\Users\...\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1252, in _run_stage
    return self._run_train()
  File "C:\Users\...\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1282, in _run_train
    self.fit_loop.run()
  File "C:\Users\...\lib\site-packages\pytorch_lightning\loops\loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "C:\Users\...\lib\site-packages\pytorch_lightning\loops\fit_loop.py", line 269, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "C:\Users\...\lib\site-packages\pytorch_lightning\loops\loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "C:\Users\...\lib\site-packages\pytorch_lightning\loops\epoch\training_epoch_loop.py", line 194, in advance
    response = self.trainer._call_lightning_module_hook("on_train_batch_start", batch, batch_idx)
  File "C:\Users\...\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1549, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "C:\Users\...\lib\site-packages\nni\retiarii\oneshot\pytorch\base_lightning.py", line 378, in on_train_batch_start
    return self.model.on_train_batch_start(batch, batch_idx, unused)
TypeError: ModelHooks.on_train_batch_start() takes 3 positional arguments but 4 were given
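
From the traceback, NNI's wrapper in base_lightning.py forwards a third positional argument (unused) that the LightningModule hook no longer accepts. As a rough illustration of the signature change (based on the Lightning 1.7 release notes, where the deprecated extra argument was removed; worth double-checking against the installed version):

# PyTorch Lightning <= 1.6: the hook still accepted a deprecated third argument
def on_train_batch_start(self, batch, batch_idx, unused=0):
    ...

# PyTorch Lightning >= 1.7: the deprecated argument was removed
def on_train_batch_start(self, batch, batch_idx):
    ...

NNI's wrapper passes batch, batch_idx and unused, so on Lightning 1.7 the bound hook receives one positional argument too many, which matches the "takes 3 positional arguments but 4 were given" message.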

Environment:

How to reproduce it?: Use the code above.

AL3708 commented 2 years ago

The issue exists for all one-shot NAS strategies. I've solved it temporarily by modifying the BaseOneShotLightningModule class in base_lightning.py, removing the unused argument from these methods:

    def on_train_batch_start(self, batch, batch_idx, unused=0):
        # original: return self.model.on_train_batch_start(batch, batch_idx, unused)
        return self.model.on_train_batch_start(batch, batch_idx)

    def on_train_batch_end(self, outputs, batch, batch_idx, unused=0):
        # original: return self.model.on_train_batch_end(outputs, batch, batch_idx, unused)
        return self.model.on_train_batch_end(outputs, batch, batch_idx)
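
A more version-tolerant variant of the same patch (just a sketch: it uses inspect to count the parameters the wrapped hook accepts, so it works on Lightning both before and after 1.7):

import inspect  # at the top of base_lightning.py

    def on_train_batch_start(self, batch, batch_idx, unused=0):
        # Forward `unused` only when the wrapped hook still accepts it (Lightning < 1.7).
        # signature() on the bound method excludes `self`, so 3 params means (batch, batch_idx, unused).
        if len(inspect.signature(self.model.on_train_batch_start).parameters) >= 3:
            return self.model.on_train_batch_start(batch, batch_idx, unused)
        return self.model.on_train_batch_start(batch, batch_idx)
        # the same pattern applies to on_train_batch_end
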
ultmaster commented 2 years ago

This is due to an API change in PyTorch Lightning v1.7.

Your temporary solution has already been merged into master and will be released soon. For now, try using pytorch-lightning < 1.7, or apply a workaround like the one you described.
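
Until the patched release is out, pinning the dependency could look like this (assuming a pip-based environment):

pip install "pytorch-lightning<1.7"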