pytorch / ignite

High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.
https://pytorch-ignite.ai
BSD 3-Clause "New" or "Revised" License
4.53k stars 615 forks source link

How to change train loader when I have started trainer.run(train_loader, max_epochs=epochs) #1454

Closed feifeifeiliu closed 3 years ago

feifeifeiliu commented 3 years ago

❓ Questions/Help/Support

Hi, thank you for your library. As mentioned in the topic, I want to know how to change train loader when I have started trainer.run(train_loader, max_epochs=epochs). For example, I use train loader A in the first 60 epoches, and after 60 epoches, I want to use train loader B, which has a different sampler with A. Could you tell me how to do that?

vfdev-5 commented 3 years ago

Hi @feifeifeiliu , thanks for the feedback. Here is how it can be done: https://pytorch.org/ignite/engine.html#ignite.engine.engine.Engine.set_data

Please, let me know if it works for you.

feifeifeiliu commented 3 years ago

Hi @vfdev-5 , I use the code segment you provided. It can work, but there is still little problem. Using train loader A, there are 47 iters in an epoch. Using train loader B, there should be 203 iters in an epoch. However, after set_data(B), there are still 47 iters in an epoch. I can sure set_data take effect. Maybe the problem is caused by these code segment. Could you take a look for me? Thank you!


def do_train_with_center(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        loss_fn,
        num_query,
        start_epoch,
        regularizer,
        of_penalty,
):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS
    tloader = train_loader
    train_loader = train_loader[0]

    weight['id'] = 1 - LOSS_WEIGHT
    weight['tri'] = weight['ap'] = LOSS_WEIGHT

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer_with_center(model, optimizer, loss_fn, regularizer, of_penalty,
                                                    device=device)
    evaluator = create_supervised_evaluator(model, metrics={
        'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
    timer = Timer(average=True)

    if ignite.__version__ in ['0.4.0.post1', '0.4.1', '0.3.0', '0.4.2']:
        checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, n_saved=2, require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=checkpoint_period), checkpointer, {'model': model,
                                                                                                  'optimizer': optimizer})
    else:
        checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=2, require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
                                                                         'optimizer': optimizer})

    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
    RunningAverage(output_transform=lambda x: x[2]).attach(trainer, 'avg_id_loss')
    RunningAverage(output_transform=lambda x: x[3]).attach(trainer, 'avg_t_loss')
    RunningAverage(output_transform=lambda x: x[4]).attach(trainer, 'avg_ap_loss')
    RunningAverage(output_transform=lambda x: x[5]).attach(trainer, 'mean')
    RunningAverage(output_transform=lambda x: x[6]).attach(trainer, 'norm')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        global LOSS_WEIGHT
        global weight
        LOSS_WEIGHT = 1. * engine.state.epoch / cfg.SOLVER.MAX_EPOCHS
        weight = weight_generate(weight, LOSS_WEIGHT)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1

        if ITER % log_period == 0:
            logger.info(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, ID Loss: {:.3f}, T Loss: {:.3f}, SAP: {:.3%}, F Mean: {:.3f}, F Norm: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e} " #, AP Loss: {:.3f}, SAP: {:.3%}
                    .format(engine.state.epoch, ITER, len(train_loader),
                            engine.state.metrics['avg_loss'],
                            engine.state.metrics['avg_id_loss'],
                            engine.state.metrics['avg_t_loss'],
                            engine.state.metrics['avg_ap_loss'],
                            engine.state.metrics['mean'],
                            engine.state.metrics['norm'],
                            engine.state.metrics['avg_acc'],
                            scheduler.get_lr()[0]))
        if len(train_loader) == ITER:
            ITER = 0

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        scheduler.step()
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        logger.info('-' * 10)

        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("mAP: {:.3%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.3%}".format(r, cmc[r - 1]))

    @trainer.on(Events.EPOCH_COMPLETED(once=2))
    def switch_dataloader():
        trainer.set_data(tloader[1])

    trainer.run(train_loader, max_epochs=epochs)``
vfdev-5 commented 3 years ago

@feifeifeiliu if you'd like to change epoch length you can do this inside the switch handler like this

data1 = ["a"] * 4
data2 = ["b"] * 20

switch_epoch = 3

def train_step(e, batch):
    print(trainer.state.epoch, trainer.state.iteration, batch, trainer.state.epoch_length)

trainer = Engine(train_step)

@trainer.on(Events.EPOCH_COMPLETED(once=switch_epoch))
def switch_dataloader():
    trainer.set_data(data2)
    # HERE WE SET NEW EPOCH LENGTH
    trainer.state.epoch_length = len(data2)

trainer.run(data1, max_epochs=10)

Hope this helps

feifeifeiliu commented 3 years ago

@vfdev-5 It seems works! Thank you for your patient and kind guidance!

vfdev-5 commented 3 years ago

@feifeifeiliu you are welcome ! Let me close the issue as resolved. Feel free to reopen if needed.