Yuyan-Li opened 5 years ago
Does anyone have an opinion on merging this?
If the deserialisation works (which @Yuyan-Li has not tested yet), sure. @DerThorsten, would you have time to do a quick test?
Let's see if @Yuyan-Li wants to contribute a test; if not, I'll write one soonish.
I can test it on my system, but I don't know how to write proper unit tests. I could write a sample script showing that it works, if that's enough.
So I tested it and the deserialisation works. I also fixed it so that it shows the proper epoch in the bar when continuing the training.
But there seems to be a problem with continuing the training after loading a checkpoint. My script crashes at the end with an error in the dataloader (ConnectionRefusedError: [Errno 111] Connection refused). This should be unrelated to the TQDM bar, because it also happens without it.
I put the script I used below (it's heavily borrowed from tests/test_training/test_basic.py). Maybe someone has time to take a look.
def _make_test_model():
    import torch.nn as nn
    from inferno.extensions.layers.reshape import AsMatrix
    toy_net = nn.Sequential(nn.Conv2d(3, 8, 3, 1, 1),
                            nn.ELU(),
                            nn.MaxPool2d(2),
                            nn.Conv2d(8, 8, 3, 1, 1),
                            nn.ELU(),
                            nn.MaxPool2d(2),
                            nn.Conv2d(8, 16, 3, 1, 1),
                            nn.ELU(),
                            nn.AdaptiveAvgPool2d((1, 1)),
                            AsMatrix(),
                            nn.Linear(16, 10))
    return toy_net


def test_serialization():
    from inferno.trainers.basic import Trainer
    from inferno.trainers.callbacks import TQDMProgressBar
    from inferno.io.box.cifar import get_cifar10_loaders
    # Make model
    net = _make_test_model()
    # Make trainer
    trainer = Trainer(model=net) \
        .build_optimizer('Adam') \
        .build_criterion('CrossEntropyLoss') \
        .build_metric('CategoricalError') \
        .validate_every((1, 'epochs')) \
        .save_every((1, 'epochs'), to_directory='saves') \
        .set_max_num_iterations(500) \
        .register_callback(TQDMProgressBar())
    train_loader, validate_loader = get_cifar10_loaders(root_directory='.', download=True)
    trainer.bind_loader('train', train_loader)
    trainer.bind_loader('validate', validate_loader)
    # Try to train
    trainer.fit()
    # Try to serialize
    trainer.save()


def test_deserialization():
    from inferno.trainers.basic import Trainer
    from inferno.io.box.cifar import get_cifar10_loaders
    net = _make_test_model()
    # Try to unserialize
    trainer = Trainer(net).save_to_directory('saves').load()
    train_loader, validate_loader = get_cifar10_loaders(root_directory='.', download=True)
    trainer.bind_loader('train', train_loader)
    trainer.bind_loader('validate', validate_loader)
    # Try to continue training
    trainer.set_max_num_iterations(800)
    trainer.fit()


if __name__ == '__main__':
    test_serialization()
    test_deserialization()
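Regarding the ConnectionRefusedError mentioned above: a quick way to check whether it comes from the dataloader's worker processes (rather than from the trainer or the checkpointing) is to rebuild the loaders without workers. The sketch below uses plain torchvision instead of inferno's get_cifar10_loaders, and the batch size and transform are placeholder choices, not values from the PR:

# Diagnostic sketch: build CIFAR-10 loaders in the main process only, so no
# worker sockets/pipes need to be torn down when the script exits.
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def make_single_process_cifar_loaders(root='.', batch_size=128):
    transform = transforms.ToTensor()
    train_set = datasets.CIFAR10(root, train=True, download=True, transform=transform)
    validate_set = datasets.CIFAR10(root, train=False, download=True, transform=transform)
    # num_workers=0 keeps loading in the main process; if the
    # ConnectionRefusedError disappears, it was a worker-teardown issue.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    validate_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False, num_workers=0)
    return train_loader, validate_loader

These loaders can then be bound with trainer.bind_loader('train', ...) exactly like the ones from get_cifar10_loaders.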
The "training epoch x" bar isn't restored correctly. If you set trainer.set_max_num_iterations(800) to something larger you will notice
Training epoch 1: : 500it [00:16, 29.77it/s] | 2/1000 [00:16<2:20:35, 8.45s/it]
even though each epoch only has 391 iterations.
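For reference, tqdm can resume a bar at an arbitrary position via its initial and total arguments, so the epoch bar could in principle be rebuilt from the restored trainer state instead of starting from zero. A minimal sketch, assuming the trainer exposes epoch_count and iteration_count counters (the attribute names are an assumption, not confirmed against the Trainer API):

from tqdm import tqdm


def rebuild_epoch_bar(trainer, iterations_per_epoch):
    # Hypothetical helper: re-create the per-epoch bar after trainer.load().
    done_in_epoch = trainer.iteration_count % iterations_per_epoch
    return tqdm(total=iterations_per_epoch,
                initial=done_in_epoch,
                desc='Training epoch {}'.format(trainer.epoch_count))

With iterations_per_epoch = 391 (the length of the CIFAR-10 train loader), the bar would then show the correct per-epoch progress after loading.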
This removes the TQDM bar from the serialization, which prevents the following error when saving the trainer:
TypeError: cannot serialize '_io.TextIOWrapper' object
I think the bar will be rebuilt automatically (I haven't tested that yet).
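For context, the usual Python pattern for keeping a non-picklable member (here the bar's underlying _io.TextIOWrapper) out of the serialized state is to override __getstate__/__setstate__ on the callback. This is a minimal sketch of that pattern, not the actual inferno implementation:

class ProgressBarCallback:
    """Sketch of a callback that excludes its live tqdm bar from pickling."""

    def __init__(self):
        self.bar = None  # created lazily once training starts

    def __getstate__(self):
        # Drop the live tqdm object (it wraps a TextIOWrapper) before pickling.
        state = self.__dict__.copy()
        state['bar'] = None
        return state

    def __setstate__(self, state):
        # The bar gets rebuilt on the next progress update after loading.
        self.__dict__.update(state)

With this in place, saving the trainer never touches the TextIOWrapper, and the bar is simply re-created the next time the callback fires.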