Closed feifeifeiliu closed 3 years ago
Hi @feifeifeiliu , thanks for the feedback. Here is how it can be done: https://pytorch.org/ignite/engine.html#ignite.engine.engine.Engine.set_data
Please, let me know if it works for you.
Hi @vfdev-5 , I use the code segment you provided. It can work, but there is still little problem. Using train loader A, there are 47 iters in an epoch. Using train loader B, there should be 203 iters in an epoch. However, after set_data(B), there are still 47 iters in an epoch. I can sure set_data take effect. Maybe the problem is caused by these code segment. Could you take a look for me? Thank you!
def do_train_with_center(
cfg,
model,
train_loader,
val_loader,
optimizer,
scheduler,
loss_fn,
num_query,
start_epoch,
regularizer,
of_penalty,
):
log_period = cfg.SOLVER.LOG_PERIOD
checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
eval_period = cfg.SOLVER.EVAL_PERIOD
output_dir = cfg.OUTPUT_DIR
device = cfg.MODEL.DEVICE
epochs = cfg.SOLVER.MAX_EPOCHS
tloader = train_loader
train_loader = train_loader[0]
weight['id'] = 1 - LOSS_WEIGHT
weight['tri'] = weight['ap'] = LOSS_WEIGHT
logger = logging.getLogger("reid_baseline.train")
logger.info("Start training")
trainer = create_supervised_trainer_with_center(model, optimizer, loss_fn, regularizer, of_penalty,
device=device)
evaluator = create_supervised_evaluator(model, metrics={
'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
timer = Timer(average=True)
if ignite.__version__ in ['0.4.0.post1', '0.4.1', '0.3.0', '0.4.2']:
checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, n_saved=2, require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=checkpoint_period), checkpointer, {'model': model,
'optimizer': optimizer})
else:
checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=2, require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
'optimizer': optimizer})
timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
# average metric to attach on trainer
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
RunningAverage(output_transform=lambda x: x[2]).attach(trainer, 'avg_id_loss')
RunningAverage(output_transform=lambda x: x[3]).attach(trainer, 'avg_t_loss')
RunningAverage(output_transform=lambda x: x[4]).attach(trainer, 'avg_ap_loss')
RunningAverage(output_transform=lambda x: x[5]).attach(trainer, 'mean')
RunningAverage(output_transform=lambda x: x[6]).attach(trainer, 'norm')
@trainer.on(Events.STARTED)
def start_training(engine):
engine.state.epoch = start_epoch
@trainer.on(Events.EPOCH_STARTED)
def adjust_learning_rate(engine):
global LOSS_WEIGHT
global weight
LOSS_WEIGHT = 1. * engine.state.epoch / cfg.SOLVER.MAX_EPOCHS
weight = weight_generate(weight, LOSS_WEIGHT)
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
global ITER
ITER += 1
if ITER % log_period == 0:
logger.info(
"Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, ID Loss: {:.3f}, T Loss: {:.3f}, SAP: {:.3%}, F Mean: {:.3f}, F Norm: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e} " #, AP Loss: {:.3f}, SAP: {:.3%}
.format(engine.state.epoch, ITER, len(train_loader),
engine.state.metrics['avg_loss'],
engine.state.metrics['avg_id_loss'],
engine.state.metrics['avg_t_loss'],
engine.state.metrics['avg_ap_loss'],
engine.state.metrics['mean'],
engine.state.metrics['norm'],
engine.state.metrics['avg_acc'],
scheduler.get_lr()[0]))
if len(train_loader) == ITER:
ITER = 0
# adding handlers using `trainer.on` decorator API
@trainer.on(Events.EPOCH_COMPLETED)
def print_times(engine):
scheduler.step()
logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
.format(engine.state.epoch, timer.value() * timer.step_count,
train_loader.batch_size / timer.value()))
logger.info('-' * 10)
timer.reset()
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
if engine.state.epoch % eval_period == 0:
evaluator.run(val_loader)
cmc, mAP = evaluator.state.metrics['r1_mAP']
logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
logger.info("mAP: {:.3%}".format(mAP))
for r in [1, 5, 10]:
logger.info("CMC curve, Rank-{:<3}:{:.3%}".format(r, cmc[r - 1]))
@trainer.on(Events.EPOCH_COMPLETED(once=2))
def switch_dataloader():
trainer.set_data(tloader[1])
trainer.run(train_loader, max_epochs=epochs)``
@feifeifeiliu if you'd like to change epoch length you can do this inside the switch handler like this
data1 = ["a"] * 4
data2 = ["b"] * 20
switch_epoch = 3
def train_step(e, batch):
print(trainer.state.epoch, trainer.state.iteration, batch, trainer.state.epoch_length)
trainer = Engine(train_step)
@trainer.on(Events.EPOCH_COMPLETED(once=switch_epoch))
def switch_dataloader():
trainer.set_data(data2)
# HERE WE SET NEW EPOCH LENGTH
trainer.state.epoch_length = len(data2)
trainer.run(data1, max_epochs=10)
Hope this helps
@vfdev-5 It seems works! Thank you for your patient and kind guidance!
@feifeifeiliu you are welcome ! Let me close the issue as resolved. Feel free to reopen if needed.
❓ Questions/Help/Support
Hi, thank you for your library. As mentioned in the topic, I want to know how to change train loader when I have started trainer.run(train_loader, max_epochs=epochs). For example, I use train loader A in the first 60 epoches, and after 60 epoches, I want to use train loader B, which has a different sampler with A. Could you tell me how to do that?