ProxylessNAS on CIFAR (#5667)

John1231983 commented 1 year ago

Hello everyone, I would like to test ProxylessNAS on a smaller dataset such as CIFAR-10 using NNI 3.0. My code is adapted from the DARTS example, but it does not work: running it fails with the error below. Could you please help me figure out how to fix this? The error comes first, followed by my code:

RuntimeError: Shape inference failed because no shape inference formula is found for AvgPool2d(kernel_size=3, stride=1, padding=1) of type AvgPool2d. Meanwhile the nested modules and functions inside failed to propagate the shape information. Please provide a `_shape_forward` member function or register a formula using `register_shape_inference_formula`.

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Reproduction of experiments in `DARTS paper <https://arxiv.org/abs/1806.09055>`__."""

import argparse
import json
import os

import nni
import numpy as np
import torch
from nni.nas.evaluator.pytorch import Lightning, ClassificationModule, Trainer
from nni.nas.experiment import NasExperiment
from nni.nas.hub.pytorch import DARTS
from nni.nas.space import model_context
from nni.nas.strategy import DARTS as DartsStrategy
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import CIFAR10

class AuxLossClassificationModule(ClassificationModule):
    """Several customization for the training of DARTS, based on default Classification.

    Compared with the default ``ClassificationModule`` this module:

    * optimizes with SGD (momentum 0.9) plus a cosine-annealing LR schedule whose
      period is ``max_epochs``;
    * when ``auxiliary_loss_weight`` is non-zero, expects the model to return
      ``(logits, aux_logits)`` during training and adds the weighted auxiliary loss;
    * ramps the model's drop-path probability at the start of every epoch.
    """
    model: DARTS

    def __init__(self,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 auxiliary_loss_weight: float = 0.4,
                 max_epochs: int = 600):
        """
        Args:
            learning_rate: Initial SGD learning rate.
            weight_decay: SGD weight decay.
            auxiliary_loss_weight: Weight of the auxiliary-head loss; ``0`` disables the
                auxiliary branch in ``training_step``.
            max_epochs: Training length; used as the cosine-annealing period and as the
                denominator of the drop-path ramp.
        """
        self.auxiliary_loss_weight = auxiliary_loss_weight
        self.max_epochs = max_epochs
        super().__init__(learning_rate=learning_rate, weight_decay=weight_decay, num_classes=10)

    def configure_optimizers(self):
        """Customized optimizer with momentum, as well as a scheduler."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            momentum=0.9,  # momentum value used for DARTS training in the DARTS paper
            lr=self.hparams.learning_rate,  # type: ignore
            weight_decay=self.hparams.weight_decay  # type: ignore
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.max_epochs, eta_min=1e-3)
        }

    def training_step(self, batch, batch_idx):
        """Training step, customized with auxiliary loss.

        Returns the (possibly combined) loss; logs ``train_loss`` and every metric
        in ``self.metrics`` computed on the main head's output.
        """
        x, y = batch
        if self.auxiliary_loss_weight:
            # Model is expected to return (main logits, auxiliary-head logits) here.
            y_hat, y_aux = self(x)
            loss_main = self.criterion(y_hat, y)
            loss_aux = self.criterion(y_aux, y)
            self.log('train_loss_main', loss_main)
            self.log('train_loss_aux', loss_aux)
            loss = loss_main + self.auxiliary_loss_weight * loss_aux
        else:
            y_hat = self(x)
            loss = self.criterion(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        for name, metric in self.metrics.items():
            self.log('train_' + name, metric(y_hat, y), prog_bar=True)
        return loss

    def on_train_epoch_start(self):
        """Set drop path probability before every epoch. This has no effect if drop path is not enabled in model."""
        # Linear ramp from 0 towards the model's configured drop_path_prob over max_epochs.
        self.model.set_drop_path_prob(self.model.drop_path_prob * self.current_epoch / self.max_epochs)

        # Logging learning rate at the beginning of every epoch
        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'])

def cutout_transform(img, length: int = 16):
    """Apply Cutout: zero a random ``length`` x ``length`` square of ``img`` in place.

    The square's centre is drawn with ``np.random.randint`` (row first, then
    column) and the square is clipped at the image borders, so it can be smaller
    near the edges. The same region is zeroed in every channel.

    Args:
        img: Image tensor; dim 1 is taken as height and dim 2 as width
            (presumably ``(C, H, W)`` after ``ToTensor`` — confirm for other layouts).
            It is multiplied in place by a float32 mask, so a float tensor is assumed.
        length: Side length of the square (``length // 2`` on each side of the centre).

    Returns:
        ``img`` itself, after masking.
    """
    height, width = img.size(1), img.size(2)
    row = np.random.randint(height)
    col = np.random.randint(width)
    half = length // 2

    top, bottom = np.clip([row - half, row + half], 0, height)
    left, right = np.clip([col - half, col + half], 0, width)

    keep = np.ones((height, width), np.float32)
    keep[top:bottom, left:right] = 0.
    img *= torch.from_numpy(keep).expand_as(img)
    return img

def get_cifar10_dataset(train: bool = True, cutout: bool = False):
    """Build the CIFAR-10 dataset (downloaded to ``./data``) with DARTS-style transforms.

    Args:
        train: If true, return the training split with random crop + horizontal flip
            augmentation; otherwise return the test split with no augmentation.
        cutout: Only used when ``train`` is true; appends ``cutout_transform`` after
            normalization.

    Returns:
        A ``CIFAR10`` dataset wrapped with ``nni.trace`` so NNI can record how it was built.
    """
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

    if train:
        transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            # Normalize (and cutout) operate on tensors, so convert the PIL image first.
            transforms.ToTensor(),
            transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
        ])
        if cutout:
            transform.transforms.append(cutout_transform)
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
        ])

    return nni.trace(CIFAR10)(root='./data', train=train, download=True, transform=transform)

def search(log_dir: str, batch_size: int = 64, **kwargs):
    """Search a DARTS cell architecture on CIFAR-10 with the ProxylessNAS strategy.

    The CIFAR-10 training set is split in half at random: one half trains the
    supernet weights, the other half is the validation data seen by the strategy.
    A FLOPs-expectation penalty (target 320e6 FLOPs, scale 0.1) is added to the loss.

    Args:
        log_dir: Directory for TensorBoard logs and the exported ``arch.json``.
        batch_size: Mini-batch size for both loaders.
        **kwargs: Ignored; lets ``main`` pass its whole config dict.

    Returns:
        The searched architecture (the ``sample`` of the first model listed by the
        strategy). It is also written to ``<log_dir>/arch.json``.
    """
    from nni.nas.execution import SequentialExecutionEngine
    from nni.nas.oneshot.pytorch.profiler import ExpectationProfilerPenalty
    from nni.nas.profiler.pytorch.flops import FlopsProfiler
    from nni.nas.space import RawFormatModelSpace
    from nni.nas.strategy import Proxyless

    model_space = DARTS(16, 8, 'cifar')

    train_data = get_cifar10_dataset()
    num_samples = len(train_data)
    indices = np.random.permutation(num_samples)
    split = num_samples // 2

    # Disjoint random halves of the training set: weights vs. validation.
    train_loader = DataLoader(
        train_data, batch_size=batch_size,
        sampler=SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=6
    )
    valid_loader = DataLoader(
        train_data, batch_size=batch_size,
        sampler=SubsetRandomSampler(indices[split:]),
        pin_memory=True, num_workers=6
    )

    evaluator = Lightning(
        AuxLossClassificationModule(0.025, 3e-4, 0., 50),
        Trainer(
            gpus=1,
            max_epochs=50,
            logger=TensorBoardLogger(log_dir, name='search')
        ),
        train_dataloaders=train_loader,
        val_dataloaders=valid_loader
    )

    # NOTE(review): FlopsProfiler relies on shape inference over the model space;
    # this is presumably where the reported "Shape inference failed ... AvgPool2d"
    # RuntimeError is raised. Per that error message, a shape formula for
    # nn.AvgPool2d must be registered with ``register_shape_inference_formula``
    # (or the module given a ``_shape_forward``) before profiling. TODO confirm.
    profiler = FlopsProfiler(model_space, torch.randn(1, 3, 32, 32),
                             count_normalization=False, count_bias=False, count_activation=False)
    penalty = ExpectationProfilerPenalty(profiler, 320e6, scale=0.1, nonlinear='absolute')

    engine = SequentialExecutionEngine()
    strategy = Proxyless(warmup_epochs=20, penalty=penalty, arc_learning_rate=1e-3)
    strategy(RawFormatModelSpace(model_space, evaluator), engine)

    arch = next(strategy.list_models()).sample

    with open(os.path.join(log_dir, 'arch.json'), 'w') as f:
        json.dump(arch, f)

    # main() stores this return value as config['arch'] and re-dumps it; returning
    # None would overwrite arch.json with null and break the subsequent train stage.
    return arch

def train(arch: dict, log_dir: str, batch_size: int = 96, ckpt_path: str = None, **kwargs):
    """Train a fixed DARTS architecture from scratch on CIFAR-10.

    Args:
        arch: Architecture dict (e.g. returned by ``search``) used to freeze the space.
        log_dir: Directory for TensorBoard logs.
        batch_size: Mini-batch size for the train and test loaders.
        ckpt_path: Optional Lightning checkpoint to resume from; forwarded to ``fit``.
        **kwargs: Ignored; lets ``main`` pass its whole config dict.
    """
    with model_context(arch):
        model = DARTS(36, 20, 'cifar', auxiliary_loss=True, drop_path_prob=0.2)

    train_data = get_cifar10_dataset(cutout=True)
    valid_data = get_cifar10_dataset(train=False)

    fit_kwargs = {}
    if ckpt_path:
        fit_kwargs['ckpt_path'] = ckpt_path

    evaluator = Lightning(
        AuxLossClassificationModule(0.025, 3e-4, 0.4, 600),
        Trainer(
            gpus=1,
            gradient_clip_val=5.,
            max_epochs=600,
            logger=TensorBoardLogger(log_dir, name='train')
        ),
        train_dataloaders=DataLoader(train_data, batch_size=batch_size, pin_memory=True, shuffle=True, num_workers=6),
        val_dataloaders=DataLoader(valid_data, batch_size=batch_size, pin_memory=True, num_workers=6),
        fit_kwargs=fit_kwargs
    )

    evaluator.fit(model)

def test(arch, weight_file, batch_size: int = 512, **kwargs):
    """Evaluate a fixed DARTS architecture with trained weights on the CIFAR-10 test set.

    Args:
        arch: Architecture dict used to freeze the DARTS space.
        weight_file: Path to a file that ``torch.load`` returns as a plain ``state_dict``
            for the network built here. NOTE(review): a Lightning ``.ckpt`` presumably
            nests weights under ``'state_dict'`` with a module prefix and would not load
            directly — confirm the intended format.
        batch_size: Mini-batch size for the test loader.
        **kwargs: Ignored; lets ``main`` pass its whole config dict.
    """
    with model_context(arch):
        model = DARTS(36, 20, 'cifar')
    # map_location='cpu' allows loading GPU-saved weights on any machine;
    # load_state_dict copies them into the model's own parameters.
    model.load_state_dict(torch.load(weight_file, map_location='cpu'))

    lightning_module = AuxLossClassificationModule(0.025, 3e-4, 0., 600)
    # Attach the trained network; otherwise validate() has no model to evaluate.
    lightning_module.set_model(model)
    trainer = Trainer(gpus=1)

    valid_data = get_cifar10_dataset(train=False)
    valid_loader = DataLoader(valid_data, batch_size=batch_size, pin_memory=True, num_workers=6)

    trainer.validate(lightning_module, valid_loader)

def main():
    """Command-line entry point: search, train and/or test a DARTS architecture.

    ``--mode`` selects the stages: ``search``, ``train``, ``test``, or ``search_train``
    (search, then train the searched architecture). Options not given on the command
    line are dropped from the config, so each stage falls back to its own defaults
    (e.g. its own ``batch_size``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['search', 'train', 'test', 'search_train'], default='search_train')
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--arch', type=str)
    parser.add_argument('--weight_file', type=str)
    parser.add_argument('--log_dir', default='lightning_logs', type=str)
    parser.add_argument('--ckpt_path', type=str)

    parsed_args = parser.parse_args()
    config = {k: v for k, v in vars(parsed_args).items() if v is not None}
    if 'arch' in config:
        config['arch'] = json.loads(config['arch'])

    mode = config['mode']
    # Fail with a clear CLI error instead of a TypeError deep inside train()/test().
    if mode in ('train', 'test') and 'arch' not in config:
        parser.error(f'--arch is required for --mode {mode}')
    if mode == 'test' and 'weight_file' not in config:
        parser.error('--weight_file is required for --mode test')

    if 'search' in mode:
        config['arch'] = search(**config)
        with open(os.path.join(config['log_dir'], 'arch.json'), 'w') as f:
            json.dump(config['arch'], f)
        print('Searched config', config['arch'])
    if 'train' in mode:
        train(**config)
    if mode == 'test':
        test(**config)


if __name__ == '__main__':
    main()