pytorch / ignite

High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.
https://pytorch-ignite.ai
BSD 3-Clause "New" or "Revised" License

Variable epoch_length for different epochs #1637

Open pmneila opened 3 years ago

pmneila commented 3 years ago

❓ Questions/Help/Support

Hi,

I'm working with a model that increases in complexity during training. To avoid memory issues, I reduce the batch size accordingly at each epoch. This means that, for a fixed length of the dataset, the number of iterations per epoch increases each epoch.

Something like this:


batch_size_per_epoch = [16, 8, 4, 2]
dataset = ImageDataset(...)

loaders = (DataLoader(dataset, batch_size=bs, shuffle=True) for bs in batch_size_per_epoch)

# Then I can run the engine either with
for i, loader in enumerate(loaders):
    engine.run(loader, max_epochs=i+1)
# or by calling engine.set_data in a properly defined event handler.

The problem is that engine.state.epoch_length is set once for the first loader and the subsequent loaders run as many iterations as the first one. Setting engine.state.epoch_length by hand is not only ugly, but also messes up the saving/loading of the engine (epoch and iterations are inferred assuming a fixed epoch length).

Is there any way to use variable epoch lengths or variable batch sizes with ignite? I've been thinking of building a new engine for each epoch, but keeping the state from previous engines, saving, loading and reusing the loggers/metrics/handlers is rather messy. Is there an alternative?
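
For reference, the per-epoch-engine workaround I mean looks roughly like this (only a sketch: the dataset is replaced by random tensors, the shared state is reduced to a single counter kept outside the engines, and the re-attachment of handlers/metrics/loggers, which is exactly the messy part, is only hinted at in a comment):

import torch
from ignite.engine import Engine

n_samples = 100
batch_size_per_epoch = [16, 8, 4, 2]
# stand-in data: one tensor of pre-built batches per epoch
loaders = [torch.rand(n_samples // bs, bs, 3, 32, 32) for bs in batch_size_per_epoch]

# state that should survive across engines, kept by hand outside of them
shared_state = {"iteration": 0}

def make_step(epoch_index):
    def train_step(engine, batch):
        shared_state["iteration"] += 1
        print(f"epoch {epoch_index + 1} - global iter {shared_state['iteration']} : {batch.shape}")
    return train_step

for epoch_index, loader in enumerate(loaders):
    engine = Engine(make_step(epoch_index))
    # every handler, metric and logger has to be re-attached to each new engine here,
    # and checkpoints of one engine do not resume cleanly into the next
    engine.run(loader, max_epochs=1)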

Best

vfdev-5 commented 3 years ago

@pmneila thanks for the question! In some sense, the epoch length can be unrelated to the size of the dataset. Under that assumption we can set it once and simply reset the input data each epoch with Engine.set_data. However, the epoch_length will remain fixed.

import torch
from ignite.engine import Engine, Events

n_samples = 100
batch_size_per_epoch = [16, 8, 4, 2]

loaders = [torch.rand(n_samples // bs, bs, 3, 32, 32) for bs in batch_size_per_epoch]

trainer = Engine(lambda e, b: print(f"{e.state.epoch} - {e.state.iteration} : {b.shape}"))
trainer.state.loader_index = 0

@trainer.on(Events.EPOCH_COMPLETED)
def set_next_loader():
    trainer.state.loader_index += 1
    print(f"Set next loader: {trainer.state.loader_index}")
    trainer.set_data(
        loaders[trainer.state.loader_index]
    )

trainer.run(loaders[trainer.state.loader_index], max_epochs=3)

gives

1 - 1 : torch.Size([16, 3, 32, 32])
1 - 2 : torch.Size([16, 3, 32, 32])
1 - 3 : torch.Size([16, 3, 32, 32])
1 - 4 : torch.Size([16, 3, 32, 32])
1 - 5 : torch.Size([16, 3, 32, 32])
1 - 6 : torch.Size([16, 3, 32, 32])
Set next loader: 1
2 - 7 : torch.Size([8, 3, 32, 32])
2 - 8 : torch.Size([8, 3, 32, 32])
2 - 9 : torch.Size([8, 3, 32, 32])
2 - 10 : torch.Size([8, 3, 32, 32])
2 - 11 : torch.Size([8, 3, 32, 32])
2 - 12 : torch.Size([8, 3, 32, 32])
Set next loader: 2
3 - 13 : torch.Size([4, 3, 32, 32])
3 - 14 : torch.Size([4, 3, 32, 32])
3 - 15 : torch.Size([4, 3, 32, 32])
3 - 16 : torch.Size([4, 3, 32, 32])
3 - 17 : torch.Size([4, 3, 32, 32])
3 - 18 : torch.Size([4, 3, 32, 32])
Set next loader: 3

So, effectively, for the smaller batches we do not cover all the data (with epoch_length fixed at 6, the epoch with batch size 8 only sees 6 × 8 = 48 of the 100 samples). Let me think more about whether we could also update the epoch length when calling set_data...

Maybe,

@trainer.on(Events.EPOCH_COMPLETED)
def set_next_loader():
    trainer.state.loader_index += 1
    print(f"Set next loader: {trainer.state.loader_index}")
    new_loader = loaders[trainer.state.loader_index]
    trainer.set_data(new_loader)
    trainer.state.epoch_length = len(new_loader)

pmneila commented 3 years ago

Thank you for the answer. That code works, but I think it does not deal well with load_state_dict. In Engine.load_state_dict, either the epoch or the iteration is computed assuming a fixed epoch_length, and something similar happens in Engine._setup_engine when computing the initial counter value for _run_once_on_dataset.

I made a quick test with load_state_dict with your example:

import torch
from ignite.engine import Engine, Events

n_samples = 100
batch_size_per_epoch = [32, 16, 8, 4, 2]

loaders = [torch.rand(n_samples // bs, bs, 3, 32, 32) for bs in batch_size_per_epoch]

trainer = Engine(lambda e, b: print(f"{e.state.epoch} - {e.state.iteration} : {b.shape}"))
trainer.state.loader_index = 0

@trainer.on(Events.EPOCH_COMPLETED)
def set_next_loader():
    trainer.state.loader_index += 1
    print(f"Set next loader: {trainer.state.loader_index}")
    new_loader = loaders[trainer.state.loader_index]
    trainer.set_data(new_loader)
    trainer.state.epoch_length = len(new_loader)

trainer.run(loaders[trainer.state.loader_index], max_epochs=2)

print("\nBefore load_state_dict...")
print(f"trainer.state.epoch={trainer.state.epoch}")  # Prints: 2
print(f"trainer.state.iteration={trainer.state.iteration}")  # Prints: 9
sd = trainer.state_dict()
trainer.load_state_dict(sd)
print("\nAfter load_state_dict...")
print(f"trainer.state.epoch={trainer.state.epoch}")  # Prints: 0 (should be 2)
print(f"trainer.state.iteration={trainer.state.iteration}")  # Prints: 9

# This should not run, but it runs and the first epoch starts with iter_counter=9 in _run_once_on_dataset (should be 0)
trainer.run(loaders[trainer.state.loader_index], max_epochs=2)

The output is:

iter_counter=0
1 - 1 : torch.Size([32, 3, 32, 32])
1 - 2 : torch.Size([32, 3, 32, 32])
1 - 3 : torch.Size([32, 3, 32, 32])
Set next loader: 1
iter_counter=0
2 - 4 : torch.Size([16, 3, 32, 32])
2 - 5 : torch.Size([16, 3, 32, 32])
2 - 6 : torch.Size([16, 3, 32, 32])
2 - 7 : torch.Size([16, 3, 32, 32])
2 - 8 : torch.Size([16, 3, 32, 32])
2 - 9 : torch.Size([16, 3, 32, 32])
Set next loader: 2

Before load_state_dict...
trainer.state.epoch=2
trainer.state.iteration=9

After load_state_dict...
trainer.state.epoch=0
trainer.state.iteration=9
iter_counter=9
1 - 10 : torch.Size([8, 3, 32, 32])
1 - 11 : torch.Size([8, 3, 32, 32])
1 - 12 : torch.Size([8, 3, 32, 32])
Set next loader: 3
iter_counter=0
2 - 13 : torch.Size([4, 3, 32, 32])
2 - 14 : torch.Size([4, 3, 32, 32])
2 - 15 : torch.Size([4, 3, 32, 32])
2 - 16 : torch.Size([4, 3, 32, 32])
2 - 17 : torch.Size([4, 3, 32, 32])
2 - 18 : torch.Size([4, 3, 32, 32])
2 - 19 : torch.Size([4, 3, 32, 32])
2 - 20 : torch.Size([4, 3, 32, 32])
2 - 21 : torch.Size([4, 3, 32, 32])
2 - 22 : torch.Size([4, 3, 32, 32])
2 - 23 : torch.Size([4, 3, 32, 32])
2 - 24 : torch.Size([4, 3, 32, 32])
2 - 25 : torch.Size([4, 3, 32, 32])
2 - 26 : torch.Size([4, 3, 32, 32])
2 - 27 : torch.Size([4, 3, 32, 32])
2 - 28 : torch.Size([4, 3, 32, 32])
2 - 29 : torch.Size([4, 3, 32, 32])
2 - 30 : torch.Size([4, 3, 32, 32])
2 - 31 : torch.Size([4, 3, 32, 32])
2 - 32 : torch.Size([4, 3, 32, 32])
2 - 33 : torch.Size([4, 3, 32, 32])
2 - 34 : torch.Size([4, 3, 32, 32])
2 - 35 : torch.Size([4, 3, 32, 32])
2 - 36 : torch.Size([4, 3, 32, 32])
2 - 37 : torch.Size([4, 3, 32, 32])
Set next loader: 4

Note that I added an additional print(f"iter_counter={iter_counter}") in _run_once_on_dataset for debugging.
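
If I read the engine code correctly, the numbers above follow from a reconstruction along these lines (a simplified sketch of the fixed-epoch_length assumption, not the actual Ignite source):

# values at the moment state_dict() is taken in the test above
iteration = 9       # 3 batches in epoch 1 + 6 batches in epoch 2
epoch_length = 12   # already bumped to len(loaders[2]) by the EPOCH_COMPLETED handler

# counters reconstructed under the fixed-epoch_length assumption
epoch = iteration // epoch_length        # 0, although 2 epochs were actually completed
iter_counter = iteration % epoch_length  # 9, although the next epoch should start at 0
print(epoch, iter_counter)               # 0 9, matching the output above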

Any suggestions? I cannot think of a simple solution to this without rewriting parts of Engine.

sdesrozis commented 3 years ago

@pmneila thanks for this discussion.

I think having a way to handle dynamic batch sizes and epoch lengths like this would also make it possible to implement the approach from the following paper:

https://arxiv.org/abs/1711.00489
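
For illustration, such a growing-batch-size schedule could be hooked into the same set_data pattern shown above. This is only a rough sketch: the dataset, the batch-size schedule and the training step are placeholders, and it inherits the epoch_length caveats discussed in this thread.

import torch
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine, Events

# placeholder dataset and schedule: double the batch size every two epochs
dataset = TensorDataset(torch.rand(256, 3, 32, 32))
batch_size_schedule = [16, 16, 32, 32, 64, 64]

def train_step(engine, batch):
    (x,) = batch
    print(f"{engine.state.epoch} - {engine.state.iteration} : {x.shape}")

trainer = Engine(train_step)

@trainer.on(Events.EPOCH_COMPLETED)
def grow_batch_size(engine):
    # completed epochs so far == 0-based index of the upcoming epoch
    next_epoch = engine.state.epoch
    if next_epoch < len(batch_size_schedule):
        loader = DataLoader(dataset, batch_size=batch_size_schedule[next_epoch], shuffle=True)
        engine.set_data(loader)
        engine.state.epoch_length = len(loader)  # same caveats as discussed above

trainer.run(
    DataLoader(dataset, batch_size=batch_size_schedule[0], shuffle=True),
    max_epochs=len(batch_size_schedule),
)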

vfdev-5 commented 3 years ago

@pmneila well, I agree there is currently no proper way to do that. Here is a hacky approach to achieve what you'd like:

import torch
from ignite.engine import Engine, Events

n_samples = 100
batch_size_per_epoch = [32, 16, 8, 4, 2]

loaders = [torch.rand(n_samples // bs, bs, 3, 32, 32) for bs in batch_size_per_epoch]

trainer = Engine(lambda e, b: print(f"{e.state.epoch} - {e.state.iteration} : {b.shape}"))
trainer.state.loader_index = 0

@trainer.on(Events.EPOCH_COMPLETED)
def set_next_loader():
    trainer.state.loader_index += 1
    print(f"Set next loader: {trainer.state.loader_index}")
    new_loader = loaders[trainer.state.loader_index]
    trainer.set_data(new_loader)
    trainer.state.epoch_length = len(new_loader)

trainer.run(loaders[trainer.state.loader_index], max_epochs=2)

print("\nBefore load_state_dict...")
print(f"trainer.state.epoch={trainer.state.epoch}")  # Prints: 2
print(f"trainer.state.iteration={trainer.state.iteration}")  # Prints: 9
sd = trainer.state_dict()
last_epoch = trainer.state.epoch

trainer.load_state_dict(sd)
# Explicitly restore the epoch
trainer.state.epoch = last_epoch

print("\nAfter load_state_dict...")
print(f"trainer.state.epoch={trainer.state.epoch}")  # Prints: 0 (should be 2)
print(f"trainer.state.iteration={trainer.state.iteration}")  # Prints: 9

# we have to reset the data when the engine starts,
# so that it avoids calling `_setup_engine()`
@trainer.on(Events.STARTED)
def reset_things():
    new_loader = loaders[trainer.state.loader_index]
    trainer.set_data(new_loader)
    trainer.state.epoch_length = len(new_loader)

trainer.run(loaders[trainer.state.loader_index], max_epochs=4)
Output:

1 - 1 : torch.Size([32, 3, 32, 32])
1 - 2 : torch.Size([32, 3, 32, 32])
1 - 3 : torch.Size([32, 3, 32, 32])
Set next loader: 1
2 - 4 : torch.Size([16, 3, 32, 32])
2 - 5 : torch.Size([16, 3, 32, 32])
2 - 6 : torch.Size([16, 3, 32, 32])
2 - 7 : torch.Size([16, 3, 32, 32])
2 - 8 : torch.Size([16, 3, 32, 32])
2 - 9 : torch.Size([16, 3, 32, 32])
Set next loader: 2

Before load_state_dict...
trainer.state.epoch=2
trainer.state.iteration=9

After load_state_dict...
trainer.state.epoch=2
trainer.state.iteration=9
3 - 10 : torch.Size([8, 3, 32, 32])
3 - 11 : torch.Size([8, 3, 32, 32])
3 - 12 : torch.Size([8, 3, 32, 32])
3 - 13 : torch.Size([8, 3, 32, 32])
3 - 14 : torch.Size([8, 3, 32, 32])
3 - 15 : torch.Size([8, 3, 32, 32])
3 - 16 : torch.Size([8, 3, 32, 32])
3 - 17 : torch.Size([8, 3, 32, 32])
3 - 18 : torch.Size([8, 3, 32, 32])
3 - 19 : torch.Size([8, 3, 32, 32])
3 - 20 : torch.Size([8, 3, 32, 32])
3 - 21 : torch.Size([8, 3, 32, 32])
Set next loader: 3
4 - 22 : torch.Size([4, 3, 32, 32])
4 - 23 : torch.Size([4, 3, 32, 32])
4 - 24 : torch.Size([4, 3, 32, 32])
4 - 25 : torch.Size([4, 3, 32, 32])
4 - 26 : torch.Size([4, 3, 32, 32])
4 - 27 : torch.Size([4, 3, 32, 32])
4 - 28 : torch.Size([4, 3, 32, 32])
4 - 29 : torch.Size([4, 3, 32, 32])
4 - 30 : torch.Size([4, 3, 32, 32])
4 - 31 : torch.Size([4, 3, 32, 32])
4 - 32 : torch.Size([4, 3, 32, 32])
4 - 33 : torch.Size([4, 3, 32, 32])
4 - 34 : torch.Size([4, 3, 32, 32])
4 - 35 : torch.Size([4, 3, 32, 32])
4 - 36 : torch.Size([4, 3, 32, 32])
4 - 37 : torch.Size([4, 3, 32, 32])
4 - 38 : torch.Size([4, 3, 32, 32])
4 - 39 : torch.Size([4, 3, 32, 32])
4 - 40 : torch.Size([4, 3, 32, 32])
4 - 41 : torch.Size([4, 3, 32, 32])
4 - 42 : torch.Size([4, 3, 32, 32])
4 - 43 : torch.Size([4, 3, 32, 32])
4 - 44 : torch.Size([4, 3, 32, 32])
4 - 45 : torch.Size([4, 3, 32, 32])
4 - 46 : torch.Size([4, 3, 32, 32])
Set next loader: 4

Anyway, I agree that separating epoch, iteration and epoch_length could be an interesting feature to have.

pmneila commented 3 years ago

Thank you. That will do for now.

sparkingdark commented 3 years ago

Hey @vfdev-5, I think this could be a good warm-up issue: implementing a dynamic epoch_length.

One approach could be a function that takes the initial epoch size and the batch parameters (and possibly the available memory) into account; we could then compute diff = total_epoch - initial_epoch and assign diff to the epoch length. However, it is not clear to me whether we would have to restart the training or whether a custom event could be used to bind this.