discovery-unicamp / Minerva

Minerva is a framework for training machine learning models for researchers.
https://discovery-unicamp.github.io/Minerva/
MIT License
3 stars 7 forks source link

Models #4

Open GabrielBG0 opened 8 months ago

GabrielBG0 commented 8 months ago

Collection of models

Features to be Implemented

otavioon commented 7 months ago

Sugestão para implementação de modelos

Modelos de aprendizado profundo normalmente são escritos em Frameworks especializados em auto-diferenciação como Pytorch, Tensorflow etc. Dos diversos frameworks, sugiro o Pytorch Lightning, que é uma extensão do Pytorch que facilita a implementação de modelos de aprendizado profundo.

Motivação para o Pytorch Lightning

O Pytorch Lightning é um framework que abstrai a implementação de modelos de aprendizado profundo, facilitando a reprodutibilidade e a depuração. Ele fornece uma API bem definida para a implementação, treinamento e teste de modelos de aprendizado profundo; gerenciamento de logs e checkpoints; facilidades para treinamento distribuído e heterogêneo; e facilidades para treinamento em clusters de GPUs (e.g. SLURM).

APIs do Pytorch Lightning

Os modelos de aprendizado profundo são implementados através de uma classe que herda de pl.LightningModule. Essa classe, minimamente, deve implementar os métodos __init__, forward, training_step, test_step e configure_optimizers.

Exemplo de modelos

A implementação de um modelo de aprendizado profundo em Pytorch Lightning é mostrada a seguir. Neste exemplo, o modelo é uma rede neural Multilayer Perceptron (MLP) com uma camada oculta.

import torch
import torch.nn as nn
import lightning as L

class MLP(L.LightningModule):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Realiza a inferência do modelo
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        # Define o que acontece em cada passo de treinamento
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        return loss

    def configure_optimizers(self):
        # Define o otimizador do modelo.
        # Os parametros do modelo são acessados através do atributo self.parameters()
        return torch.optim.Adam(self.parameters(), lr=1e-3)

Treinamento

Uma vez definido o modelo, o treinamento é feito através da classe pl.Trainer. Esta classe é responsável por gerenciar o treinamento, validação e teste do modelo. O treinamento é feito através do método trainer.fit(model, ...), onde model é o modelo a ser treinado. O código a seguir mostra um exemplo de treinamento de um modelo.

from lightning import Trainer

...
model = MLP(input_size, hidden_size, output_size)
trainer = Trainer(max_epochs=10)
trainer.fit(model, train_dataloader, val_dataloader)

O método trainer.fit abstrai o laço de treinamento comumente observado em códigos Pytorch. Este método chama o método configure_optimizers do modelo para definir o otimizador e, em seguida, chama os métodos training_step e validation_step do modelo para realizar o treinamento e a validação, respectivamente. Estes métodos são chamados para cada lote de dados (batch) do conjunto de treinamento e validação.

Extensão por meio de callbacks

A extensão do funcionamento do método trainer.fit é feita através de callbacks. Os callbacks são método que são chamados em pontos específicos do treinamento, como no início e no final de cada época (on_train_epoch_start e on_train_epoch_end, por exemplo) e possuem uma API bem definida. O Pytorch Lightning fornece alguns callbacks que podem ser usados diretamente ou estendidos para implementar funcionalidades específicas.

Logging

O Lighning oferece uma ferramenta para gerenciamento de logs que permite o registro de métricas e hiperparâmetros do modelo. Por exemplo, métricas de treinamento e validação podem ser registradas através do método self.log do modelo. O código a seguir mostra um exemplo de registro de métricas.

def training_step(self, batch, batch_idx):
    ...
    self.log('train_loss', loss)
    return loss

Os logs são armazenados em arquivos no formato Tensorboard e podem ser visualizados através do comando tensorboard --logdir=logs. Entretanto, os logs podem ser configurados para serem armazenados em outros formatos, como um simples CSV, MLFlow, Neptune, Comet e WandB.

Sugestão de implementação

Como descrito acima, o Pytorch Lightning abstrai a implementação de modelos de aprendizado profundo, além de facilitar a reprodutibilidade, organização e a depuração. Entretanto, a implementação de modelos não segue uma receita de bolo e depende do problema a ser resolvido. Desta forma, além de sugerir a implementação dos nossos modelos usando Pytorch Lightning, sugiro algumas diretrizes para a implementação e uso dos modelos:

def training_step(self, batch, batch_idx):
    ...
    self.log('train_loss', loss, on_epoch=True, on_step=False, prog_bar=True, logger=True)
    return loss

from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    mode='min',
    monitor='val_loss',
    dirpath='checkpoints/',
    filename='model-{epoch:02d}-{val_loss:.2f}',
)

trainer = Trainer(
    callbacks=[checkpoint_callback],
    ...
)