pytorch / ignite

High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.
https://pytorch-ignite.ai
BSD 3-Clause "New" or "Revised" License

Higher API for training #912

Open Minyus opened 4 years ago

Minyus commented 4 years ago

πŸš€ Feature

A higher-level API for training would improve the usability of Ignite, much as Keras did.

My implementation here could be used as the base code.

sdesrozis commented 4 years ago

Thank you for this PR and your help!!

Let’s now discuss and design this very nice feature 😊

sdesrozis commented 4 years ago

I really like the idea of having open (and clear) core tools (engine, handlers, etc.) and an API helper for composing applications that mix these core tools 👍🏻

As we discussed in the Slack channel, we could think about very high-level features, including a graphical frontend. But the first step is focused on the trainer 😊

Could you use your API on some examples from ignite/examples? Thank you again!

vfdev-5 commented 4 years ago

In my opinion, it would be interesting to create something like:


from ignite.engine import Engine


class Trainer(Engine):

    def __init__(self, *args, **kwargs):
        ...

    def train_step(self, engine, batch):
        ...

    # etc.
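For illustration, here is a minimal sketch of how such a class could be wired up, assuming the step method is registered as the Engine's process function (model, optimizer and loss_fn are placeholder names, not part of the proposal):

from ignite.engine import Engine


class Trainer(Engine):
    # illustrative sketch only, not a proposed implementation

    def __init__(self, model, optimizer, loss_fn):
        # register the bound step method as the Engine's process function
        super().__init__(self.train_step)
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn

    def train_step(self, engine, batch):
        # engine is this Trainer instance; batch comes from the data loader
        self.model.train()
        self.optimizer.zero_grad()
        x, y = batch
        loss = self.loss_fn(self.model(x), y)
        loss.backward()
        self.optimizer.step()
        return loss.item()

Training would then be a matter of Trainer(model, optimizer, loss_fn).run(train_loader, max_epochs=...).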

cc @ericspod

ericspod commented 4 years ago

I have simplified what I use in my code base, which is what @vfdev-5 is referencing here.

The idea is to extend the Engine class with a BaseEngine (a different name would be good) which adds methods for converting between numpy arrays and tensors, wrapper methods for doing forward passes through the network and loss functions, and helpers for simplifying inference with a network. The Trainer and Evaluator classes inherit from this and provide their own methods for implementing a training loop (or using an external callable as in the base Engine class), for evaluating the network on all the data from a dataset and returning the mean loss, and so on.
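A rough sketch of that structure (this is not the referenced code base; the class and method names here are placeholders):

import torch
from ignite.engine import Engine


class BaseEngine(Engine):
    # shared helpers: array/tensor conversion, forward passes, inference

    def to_tensor(self, arr, device="cpu"):
        return torch.as_tensor(arr, device=device)

    def to_numpy(self, tensor):
        return tensor.detach().cpu().numpy()

    def forward(self, net, *inputs):
        return net(*inputs)


class Trainer(BaseEngine):
    # would implement a default training loop as its process function
    ...


class Evaluator(BaseEngine):
    # would run the network over a dataset and report e.g. the mean loss
    ...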

This is a much more opinionated set of classes than what is provided with Ignite's Engine. I think Ignite is right to provide a base level of general-purpose mechanisms rather than a more involved notion of what should be done. These classes, for example, work best with a training regime where the batch contains inputs to both the network and a loss function, and where the network produces values for the loss function as well. Not all training schemes look like this. Adding a high-level API might be useful to many, but as long as it's optional Ignite can stay generic.

These types use thread locks in places since I use a separate thread in Jupyter notebooks for training and the main thread for plotting current progress, which is useful in places like Colab where TensorBoard and Visdom aren't available. It would be better to change Engine to have its own internal lock which is acquired when a step begins, before ITERATION_STARTED events are fired; this ensures synchronization with events used to summarize or log the current state. I have a logging class implemented as an event handler which keeps a list of losses and evaluation values; this might not yet have the current iteration's results when queried at a moment when the values stored in the State do contain current data.
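A handler-based approximation of that locking scheme might look like the sketch below; the change proposed above would instead acquire the lock inside Engine itself, before the event is fired.

import threading

from ignite.engine import Engine, Events


def add_step_lock(engine: Engine) -> threading.Lock:
    # illustrative sketch: hold the lock for the duration of each iteration
    # so another thread (e.g. a plotting thread) sees a consistent state
    lock = threading.Lock()
    engine.add_event_handler(Events.ITERATION_STARTED, lambda e: lock.acquire())
    engine.add_event_handler(Events.ITERATION_COMPLETED, lambda e: lock.release())
    return lock

The plotting thread would then read the logged values inside a `with lock:` block.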

These types also use the State object to store a lot more than normal. The output from the iteration callable is still stored, but inputs to the network, network outputs, and loss function outputs are stored as separate variables. There are other variables added by the logging handler for storing the log values. I am not sure whether using State in this way constitutes feature abuse relative to how Ignite was intended to be used, but it works for me, so it's a topic for a design discussion.
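For illustration, stashing extra values on the State from a process function could look like this (net and loss_fn are assumed to exist in the surrounding scope; the attribute names are made up):

def train_step(engine, batch):
    # illustrative only: State accepts arbitrary attributes,
    # so handlers can read these values later
    inputs, targets = batch
    outputs = net(inputs)
    loss = loss_fn(outputs, targets)
    engine.state.inputs = inputs
    engine.state.net_outputs = outputs
    engine.state.loss_value = loss.item()
    return loss.item()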

vfdev-5 commented 3 years ago

Here is another idea of a higher level API (similar to others) :

import timm

from argus import Model
from argus.callbacks import MonitorCheckpoint, EarlyStopping, ReduceLROnPlateau

class TimmModel(Model):
    nn_module = timm.create_model

if __name__ == "__main__":

    model = TimmModel(params)

    callbacks = [
        MonitorCheckpoint(dir_path='mnist', monitor='val_accuracy', max_saves=3),
        EarlyStopping(monitor='val_accuracy', patience=9),
        ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=3)
    ]

    model.fit(train_loader,
              val_loader=val_loader,
              num_epochs=50,
              metrics=['accuracy'],
              callbacks=callbacks,
              metrics_on_train=True)

and behind model.fit, something like our Engine is set up.
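For comparison, a rough sketch of a similar fit-style front end built directly on Ignite's existing helpers (not a proposed API, just an illustration; the function name and metric choice are arbitrary):

from ignite.engine import (Events, create_supervised_trainer,
                           create_supervised_evaluator)
from ignite.metrics import Accuracy


def fit(model, optimizer, loss_fn, train_loader, val_loader, num_epochs):
    # illustrative wrapper: hides trainer/evaluator setup behind one call
    trainer = create_supervised_trainer(model, optimizer, loss_fn)
    evaluator = create_supervised_evaluator(model, metrics={"accuracy": Accuracy()})

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        evaluator.run(val_loader)
        acc = evaluator.state.metrics["accuracy"]
        print(f"epoch {engine.state.epoch}: val_accuracy={acc:.4f}")

    trainer.run(train_loader, max_epochs=num_epochs)
    return trainer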

vfdev-5 commented 3 years ago

More implementations to explore: