pytorch / ignite

High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.
BSD 3-Clause "New" or "Revised" License

Run metrics calculation in a background process #3244

Closed: H4dr1en closed this issue 4 months ago

H4dr1en commented 5 months ago

❓ Questions/Help/Support

I am training neural networks using pytorch-ignite for a computer vision task. Currently, calculating metrics during validation takes a significant part of the validation time, and I want to optimise this so that most of the time is spent in GPU inference operations.

What I have in mind:

  1. Move the computation of the metrics from the validation step to a background process/queue
  2. Schedule the computation of the metrics after each iteration by sending the results of the inference to the metrics queue

This would work out of the box, but I also use several handlers at the end of each validation epoch (logging/checkpoint/lr scheduler/early stopping) that depend on the metrics, so I need a sync point at the end of the epoch to wait for the metrics computation to finish before triggering these handlers.

As an alternative, I could rework each metric to make it faster, potentially using the GPU, but this is also a long process since I would need to do it for every existing and new metric. I'd rather have them computed in the background so that I don't have to care about the efficiency of each implementation.

My questions are:

  1. What would be the cleanest way to implement this using pytorch-ignite?
  2. If pytorch-ignite doesn't provide a nice interface for this use case, would it make sense to extend the library to support it?
  3. Is there a blind spot I am missing? Is there a better way to deal with this situation?
vfdev-5 commented 5 months ago

@H4dr1en thanks for asking this interesting question!

Point 1, "Move the computation of the metrics from the validation step to a background process/queue", looks interesting, but here I mostly wonder about computing resource usage. I'm thinking about the following:

I read point 2, "Schedule the computation of the metrics after each iteration by sending the results of the inference to the metrics queue", as computing the metrics on the training dataset using a constantly updated model. I'm not sure this is a good idea: the final value won't correspond to the latest model.

As for scheduling a new process for metrics computation, let me think a bit about what an implementation with ignite would look like, and then we can follow up on your point "If pytorch-ignite doesn't provide a nice interface for this use case, would it make sense to extend the library to support it?".

Currently, calculating metrics during validation takes a significant part of the validation time and I want to optimise this, so that most of the time is spent in GPU inference operations

Thinking about this statement, wouldn't it be possible to run validation at larger intervals, so that a larger share of the time is spent on training? For example:

- @trainer.on(Events.EPOCH_COMPLETED)
+ @trainer.on(Events.EPOCH_COMPLETED(every=100))
def run_validation():
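To make the trade-off concrete, here is a small plain-Python sketch of the epochs on which such a handler would fire. It assumes the every filter passes when the epoch number is a multiple of every (epochs counted from 1), which matches ignite's every event filter as far as we know; validation_epochs is a hypothetical helper, not an ignite function.

```python
from typing import List


def validation_epochs(max_epochs: int, every: int) -> List[int]:
    """Epochs on which an EPOCH_COMPLETED(every=...) handler would run.

    Assumption: the filter passes when epoch % every == 0, epochs starting at 1.
    """
    return [epoch for epoch in range(1, max_epochs + 1) if epoch % every == 0]
```

For example, validation_epochs(300, 100) returns [100, 200, 300]: 3 points on the validation curve instead of 300.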
H4dr1en commented 5 months ago

Hi @vfdev-5,

Thanks for the fast answer! I think there is a small confusion: I don't want to move the whole validation logic into a background process, only the computation of the metrics. So both the training and validation steps would still run in the main process. The only part that I want to defer to a background process, to unblock training/validation, is the computation of the metrics:

@validator.on(Events.ITERATION_COMPLETED)
def compute_iteration_metrics(validation_engine):
    # Current state: long running, CPU bound, runs after every validation iteration
    validation_engine.state.output = compute_metrics(validation_engine.state.output)

    # My idea: send this to the background to unblock the rest of the program
    # But I don't know how it would play with updating the metrics (e.g. RunningAverage)

Then at the end of each validation epoch, I would wait and collect the metrics so that handlers depending on them can run:

@validator.on(Events.EPOCH_COMPLETED)  # Must be the first of these events because others might depend on state.metrics
def aggregate_iteration_metrics(validation_engine):

    # Here I would pull from the result queue
    metric_results = metrics_results_queue.get()

    # And somehow integrate them/trigger the metrics like RunningAverage
    RunningAverage.step(metric_results)   # I have no idea on how to do this at this point
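One way to think about the "integrate them" step, sketched in plain Python under the assumption that each background task returns a (metric_sum, num_samples) pair for its batch: at EPOCH_COMPLETED, sum both fields and divide, rather than averaging per-batch means. The second helper replays per-iteration values in submission order through an exponential moving average; its update rule is assumed to mirror ignite's RunningAverage (where alpha defaults to 0.98). Both function names are hypothetical, not ignite API.

```python
from typing import Iterable, Optional, Tuple


def epoch_mean(batch_results: Iterable[Tuple[float, int]]) -> float:
    """Combine per-batch (metric_sum, num_samples) pairs into one epoch-level mean.

    Weighting by sample count avoids over-weighting a smaller last batch,
    which a plain mean of per-batch means would do.
    """
    total, count = 0.0, 0
    for metric_sum, num_samples in batch_results:
        total += metric_sum
        count += num_samples
    if count == 0:
        raise ValueError("no samples were accumulated")
    return total / count


def replay_running_average(values: Iterable[float], alpha: float = 0.98) -> Optional[float]:
    """Replay per-iteration values, in submission order, through an exponential moving average.

    First value is taken as-is, then avg = alpha * avg + (1 - alpha) * value.
    Returns None if no values were given.
    """
    avg: Optional[float] = None
    for value in values:
        avg = value if avg is None else alpha * avg + (1.0 - alpha) * value
    return avg
```

With results [(2.0, 4), (3.0, 4), (2.0, 2)], epoch_mean gives 7 / 10 = 0.7, whereas the mean of the per-batch means (0.5, 0.75, 1.0) would be 0.75. Since the background results are collected in submission order, they can be replayed in that same order at epoch end.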

Thinking about this statement, wouldn't it be possible to run validation at larger intervals, so that a larger share of the time is spent on training? For example:

This would certainly help, but at the cost of sparser validation curves and the need for different LRScheduler and EarlyStopping values.
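If validation does run less often, patience-style parameters need rescaling, because EarlyStopping's patience counts how many times its handler is called (i.e. validation runs), not training epochs. A minimal sketch of the conversion; patience_in_validation_runs is a hypothetical helper name:

```python
import math


def patience_in_validation_runs(patience_epochs: int, every: int) -> int:
    """Convert a patience expressed in training epochs into a number of validation runs.

    Assumes validation (and thus the EarlyStopping handler) runs once every `every` epochs.
    """
    if patience_epochs < 1 or every < 1:
        raise ValueError("patience_epochs and every must be positive")
    # Round up so the model is never stopped earlier than intended, and keep at least 1
    return max(1, math.ceil(patience_epochs / every))
```

For example, a patience of 10 epochs with validation every 4 epochs becomes 3 validation runs.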

vfdev-5 commented 5 months ago

So, to confirm: ideally, the compute_metrics(...) call in the per-iteration handler should run in another process, and at metric_results = metrics_results_queue.get() we join the process and get all the results? Each iteration would submit a new task, and the join call would wait until all tasks are done.
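The submit-per-iteration / wait-at-epoch-end pattern can be sketched with the standard library's concurrent.futures, independently of ignite. The loop below stands in for the engine: the submit call plays the role of an ITERATION_COMPLETED handler, and the result() loop plays the role of the EPOCH_COMPLETED join. compute_metrics here is a hypothetical per-batch mean absolute error, not the real metric from the question.

```python
from concurrent.futures import Executor, Future, ProcessPoolExecutor, ThreadPoolExecutor
from typing import Callable, Dict, List

# One validation batch output: predictions and targets (hypothetical layout)
Batch = Dict[str, List[float]]


def compute_metrics(output: Batch) -> float:
    """Hypothetical CPU-bound metric: mean absolute error of one batch.

    Defined at module top level so worker processes can unpickle it.
    """
    preds, targets = output["y_pred"], output["y"]
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


def run_epoch(
    batches: List[Batch],
    executor_factory: Callable[..., Executor] = ProcessPoolExecutor,
    max_workers: int = 2,
) -> float:
    """Submit one task per iteration, then wait for all of them at the end of the epoch.

    ThreadPoolExecutor can be passed as executor_factory for quick in-process tests.
    """
    pending: List[Future] = []
    with executor_factory(max_workers=max_workers) as pool:
        for output in batches:
            # ITERATION_COMPLETED role: hand the output to a worker, return immediately
            pending.append(pool.submit(compute_metrics, output))
        # EPOCH_COMPLETED role: block until every submitted task has finished
        per_batch = [f.result() for f in pending]
    # All batches have the same size here, so the mean of per-batch values is the epoch mean
    return sum(per_batch) / len(per_batch)
```

With the two batches {"y_pred": [0.0, 1.0], "y": [0.0, 0.0]} and {"y_pred": [2.0, 2.0], "y": [1.0, 1.0]}, run_epoch returns (0.5 + 1.0) / 2 = 0.75. For pure-Python metrics the process pool is what sidesteps the GIL; PyTorch ops generally release the GIL, so a thread pool may also be worth measuring for tensor-heavy metrics.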

H4dr1en commented 5 months ago

Yes exactly 👍

vfdev-5 commented 5 months ago

@H4dr1en here is a prototype of running handlers in a process pool:

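import time

import torch
from ignite.engine import Engine, Events
from ignite.utils import setup_logger, logging

import torch.multiprocessing as mp


def long_running_computation(data):
    # CPU-bound stand-in for a slow metric
    m = data["c"]
    v = m.sum()
    for _ in range(10000):
        v = v - m.mean()
        v = m.log_softmax(dim=1).sum() + v

    return v + data["a"] + data["b"]


def run():

    eval_data = range(10)

    for with_mp in [True, False]:

        def eval_step(engine, batch):
            # forward pass latency
            print(f"{engine.state.epoch} / {engine.state.max_epochs} | {engine.state.iteration} - batch: {batch}", flush=True)
            return {
                "a": torch.tensor(engine.state.iteration, dtype=torch.float32),
                "b": torch.rand(()).item(),
                "c": torch.rand(128, 5000),
            }

        validator = Engine(eval_step)

        # pick a reasonable value of workers:
        if with_mp:
            pool = mp.Pool(processes=2)

        @validator.on(Events.EPOCH_STARTED)
        def reset_results():
            # user-defined list on Engine.state ("results" is a hypothetical name)
            validator.state.results = []

        @validator.on(Events.ITERATION_COMPLETED)
        def do_long_running_computation():
            if with_mp:
                # non-blocking: submit the task and keep the AsyncResult handle
                validator.state.results.append(
                    pool.apply_async(long_running_computation, (validator.state.output,))
                )
            else:
                # blocking: compute in the main process
                validator.state.results.append(long_running_computation(validator.state.output))

        @validator.on(Events.EPOCH_COMPLETED)
        def gather_results():
            if with_mp:
                # wait for all submitted tasks and replace the handles by their values
                validator.state.results = [
                    r.get() for r in validator.state.results
                ]
            validator.state.metrics["abc"] = sum(validator.state.results)

        start = time.time()
        validator.run(eval_data)
        elapsed = time.time() - start
        print("Elapsed time:", elapsed)
        if with_mp:
            pool.close()
            pool.join()


if __name__ == "__main__":
    run()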

python -u

1 / 1 | 1 - batch: 0
1 / 1 | 2 - batch: 1
1 / 1 | 3 - batch: 2
1 / 1 | 4 - batch: 3
1 / 1 | 5 - batch: 4
1 / 1 | 6 - batch: 5
1 / 1 | 7 - batch: 6
1 / 1 | 8 - batch: 7
1 / 1 | 9 - batch: 8
1 / 1 | 10 - batch: 9
Elapsed time: 22.257242918014526
1 / 1 | 1 - batch: 0
1 / 1 | 2 - batch: 1
1 / 1 | 3 - batch: 2
1 / 1 | 4 - batch: 3
1 / 1 | 5 - batch: 4
1 / 1 | 6 - batch: 5
1 / 1 | 7 - batch: 6
1 / 1 | 8 - batch: 7
1 / 1 | 9 - batch: 8
1 / 1 | 10 - batch: 9
Elapsed time: 26.21021580696106

The number of pool processes should be chosen carefully: PyTorch ops can be multi-threaded, so using too many processes can degrade performance. Let me know if this is what you had in mind.
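A rough heuristic for that choice, sketched under stated assumptions: keep reserved_cores free for the main (inference) process, size the pool so that workers times threads per worker does not exceed the remaining cores, and cap each worker's intra-op threads with torch.set_num_threads in a pool initializer. The helper names and the heuristic itself are illustrative, not ignite recommendations; multiprocessing.Pool accepts the same initializer/initargs pair if you prefer it to ProcessPoolExecutor.

```python
import os
from concurrent.futures import ProcessPoolExecutor
from typing import Optional


def pick_num_workers(threads_per_worker: int, reserved_cores: int = 1,
                     cpu_count: Optional[int] = None) -> int:
    """Heuristic: leave reserved_cores for the main process, split the rest among workers."""
    cores = cpu_count if cpu_count is not None else (os.cpu_count() or 1)
    available = max(1, cores - reserved_cores)
    return max(1, available // max(1, threads_per_worker))


def limit_torch_threads(num_threads: int) -> None:
    """Pool initializer: cap intra-op threads in each worker, if torch is installed."""
    try:
        import torch
    except ImportError:
        return
    torch.set_num_threads(num_threads)


def make_metrics_pool(threads_per_worker: int = 1) -> ProcessPoolExecutor:
    """Build a process pool sized by the heuristic above (hypothetical helper)."""
    return ProcessPoolExecutor(
        max_workers=pick_num_workers(threads_per_worker),
        initializer=limit_torch_threads,
        initargs=(threads_per_worker,),
    )
```

On an 8-core machine with one reserved core, pick_num_workers(2, cpu_count=8) gives 3 workers of 2 threads each.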

If we want to add something similar to the ignite API, we will have to think carefully about the public interface.

H4dr1en commented 4 months ago

Hi @vfdev-5, thanks for this super example 👍 Yes, it covers most of my needs!

I have some questions:

vfdev-5 commented 4 months ago

What is the, is it something internal? How does it work?

It is just a user-defined list manually created on an Engine.state: = [], not something ignite internal.

More specifically: here we are writing the metric values directly to engine.state.metrics; would the metrics/logger properly pick up the values of each iteration? How can we ensure this?

Yes, loggers configured to log metrics take their values from engine.state.metrics:

To ensure that the logger picks up the value, you have to attach its handler after the gather_results handler. While debugging, you can check that the handlers on the event are registered in the desired order, for example:
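A toy model (plain Python, not ignite's Engine) of why the registration order matters: handlers attached to the same event run in the order they were added, so a logger attached before gather_results reads the metric before it has been written. ToyEngine and build are hypothetical names.

```python
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, List


class ToyEngine:
    """Minimal event dispatcher: handlers for one event run in registration order."""

    def __init__(self) -> None:
        self.metrics: Dict[str, Any] = {}
        self._handlers: DefaultDict[str, List[Callable[["ToyEngine"], None]]] = defaultdict(list)

    def on(self, event: str):
        def decorator(fn):
            self._handlers[event].append(fn)
            return fn
        return decorator

    def fire(self, event: str) -> None:
        for handler in self._handlers[event]:
            handler(self)


def build(logger_first: bool) -> List[Any]:
    """Register a metric-writing and a metric-logging handler, fire once, return what was logged."""
    engine = ToyEngine()
    logged: List[Any] = []

    def gather_results(e: ToyEngine) -> None:
        e.metrics["abc"] = 42.0

    def log_metrics(e: ToyEngine) -> None:
        logged.append(e.metrics.get("abc"))

    order = [log_metrics, gather_results] if logger_first else [gather_results, log_metrics]
    for handler in order:
        engine.on("EPOCH_COMPLETED")(handler)
    engine.fire("EPOCH_COMPLETED")
    return logged
```

build(logger_first=False) logs [42.0], while build(logger_first=True) logs [None] because the logger runs before the metric exists.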

H4dr1en commented 4 months ago

That's perfect 💯 Closing the issue for now, this should do it 👍