pasqal-io / qadence

Digital-analog quantum programming interface
https://pasqal-io.github.io/qadence/latest/
Apache License 2.0

[Refactoring] Alternative ways of defining tensorboard metrics #202

Closed: atiyo closed this issue 2 weeks ago

atiyo commented 9 months ago

At the moment, `train` accepts a function `write_tensorboard(writer, loss, metrics, iteration)`, where `metrics` is returned from a custom loss function.

However, we might want to log non-training-related metrics, e.g. a plot, or a mean squared error against a known benchmark. These sorts of metrics are a) potentially expensive to calculate at every training step and b) conceptually independent of the training loop.

It would be cool if there were an alternative way of logging custom metrics to TensorBoard without having to do so via the loss function.
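
For illustration, something like the following minimal sketch, where the metric is computed and logged on its own schedule, decoupled from the loss (all names here, `benchmark_mse` and `eval_callback` included, are hypothetical and not part of the current API):

    from torch.utils.tensorboard import SummaryWriter

    # Hypothetical sketch: log an expensive, training-independent metric on
    # its own schedule. `benchmark_mse` and `eval_callback` are illustrative
    # names only, not part of the current qadence API.
    def benchmark_mse(model) -> float:
        # stand-in for a potentially expensive user-defined benchmark
        return 0.0

    def eval_callback(
        writer: SummaryWriter, model, iteration: int, every: int = 100
    ) -> None:
        # only pay the cost of the benchmark every `every` iterations
        if iteration % every == 0:
            writer.add_scalar("benchmark/mse", benchmark_mse(model), iteration)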

nmheim commented 9 months ago

I think the most general way of implementing this would be to redefine our `train_with_grad` function and do:

    # outer epoch loop
    for iteration in progress.track(range(init_iter, init_iter + config.max_iter)):
        try:
            # in case no data is needed by the model; this is the case,
            # for example, for quantum models that have no classical
            # input data (e.g. chemistry)
            if dataloader is None:
                loss, metrics = optimize_step(model, optimizer, loss_fn, None)
                loss = loss.item()

            elif isinstance(dataloader, (DictDataLoader, DataLoader)):
                data = data_to_device(next(dl_iter), device)  # type: ignore[arg-type]
                loss, metrics = optimize_step(model, optimizer, loss_fn, data)

            else:
                raise NotImplementedError(
                    f"Unsupported dataloader type: {type(dataloader)}. "
                    "You can use e.g. `qadence.ml_tools.to_dataloader` to build a dataloader."
                )

            # proposed hook: runs user/default callbacks every iteration
            iteration_callback()

        except KeyboardInterrupt:
            print("Terminating training gracefully after the current iteration.")
            break

    # Final writing and checkpointing
    final_callback()

    return model, optimizer

instead of what's currently being done with "hardcoded" functions. The hardcoded functions should be called in the default `iteration_callback`/`final_callback`, and we need a way to nicely construct one callback from a list of callbacks (see the sketch below).
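
A minimal sketch of such a composition, assuming a callback is any callable taking the current iteration (neither `Callback` nor `chain_callbacks` exists in qadence; the names are illustrative):

    from typing import Callable, Sequence

    # Sketch only: build one callback from a list of callbacks by calling
    # them in order. `Callback` and `chain_callbacks` are assumed names.
    Callback = Callable[[int], None]

    def chain_callbacks(callbacks: Sequence[Callback]) -> Callback:
        def combined(iteration: int) -> None:
            for callback in callbacks:
                callback(iteration)
        return combined

The default `iteration_callback` would then just be `chain_callbacks([...])` over the currently hardcoded writing and checkpointing functions, and users could append their own entries to the list.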

smitchaudhary commented 1 month ago

@DanieleCucurachi FYI!

chMoussa commented 2 weeks ago

Closed via #533.