Open GabrielBG0 opened 8 months ago
Modelos de aprendizado profundo normalmente são escritos em Frameworks especializados em auto-diferenciação como Pytorch, Tensorflow etc. Dos diversos frameworks, sugiro o Pytorch Lightning, que é uma extensão do Pytorch que facilita a implementação de modelos de aprendizado profundo.
O Pytorch Lightning é um framework que abstrai a implementação de modelos de aprendizado profundo, facilitando a reprodutibilidade e a depuração. Ele fornece uma API bem definida para a implementação, treinamento e teste de modelos de aprendizado profundo; gerenciamento de logs e checkpoints; facilidades para treinamento distribuído e heterogêneo; e facilidades para treinamento em clusters de GPUs (e.g. SLURM).
Os modelos de aprendizado profundo são implementados através de uma classe que herda de pl.LightningModule
. Essa classe, minimamente, deve implementar os métodos __init__
, forward
, training_step
, test_step
e configure_optimizers
.
__init__
é usado para definir os componentes do modelo, como camadas, funções de ativação, etc. Os componentes podem ser definidos diretamente no método __init__
ou recebidos como parâmetros do construtor da classe (modularização). Além disso, estes componentes podem ser objetos Pytorch ou objetos do Pytorch Lightning. Por fim, também são definidos (ou recebidos) os hiperparâmetros do modelo, como taxa de aprendizado, número de camadas, etc. forward
é usado para definir o fluxo de dados no modelo e é chamado quando o modelo é usado para inferência.training_step
, validation_step
e test_step
são usados para definir o que acontece em cada passo do treinamento, validação e teste, respectivamente. Estes métodos, por padrão, recebem um lote de dados (batch) e retornam as métricas calculadas no passo.configure_optimizers
é usado para definir o otimizador do modelo.A implementação de um modelo de aprendizado profundo em Pytorch Lightning é mostrada a seguir. Neste exemplo, o modelo é uma rede neural Multilayer Perceptron (MLP) com uma camada oculta.
import torch
import torch.nn as nn
import lightning as L
class MLP(L.LightningModule):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, output_size)
self.relu = nn.ReLU()
def forward(self, x):
# Realiza a inferência do modelo
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
def training_step(self, batch, batch_idx):
# Define o que acontece em cada passo de treinamento
x, y = batch
y_hat = self(x)
loss = nn.CrossEntropyLoss()(y_hat, y)
return loss
def configure_optimizers(self):
# Define o otimizador do modelo.
# Os parametros do modelo são acessados através do atributo self.parameters()
return torch.optim.Adam(self.parameters(), lr=1e-3)
Uma vez definido o modelo, o treinamento é feito através da classe pl.Trainer
. Esta classe é responsável por gerenciar o treinamento, validação e teste do modelo. O treinamento é feito através do método trainer.fit(model, ...)
, onde model
é o modelo a ser treinado.
O código a seguir mostra um exemplo de treinamento de um modelo.
from lightning import Trainer
...
model = MLP(input_size, hidden_size, output_size)
trainer = Trainer(max_epochs=10)
trainer.fit(model, train_dataloader, val_dataloader)
O método trainer.fit
abstrai o laço de treinamento comumente observado em códigos Pytorch. Este método chama o método configure_optimizers
do modelo para definir o otimizador e, em seguida, chama os métodos training_step
e validation_step
do modelo para realizar o treinamento e a validação, respectivamente. Estes métodos são chamados para cada lote de dados (batch) do conjunto de treinamento e validação.
A extensão do funcionamento do método trainer.fit
é feita através de callbacks.
Os callbacks são método que são chamados em pontos específicos do treinamento, como no início e no final de cada época (on_train_epoch_start
e on_train_epoch_end
, por exemplo) e possuem uma API bem definida. O Pytorch Lightning fornece alguns callbacks que podem ser usados diretamente ou estendidos para implementar funcionalidades específicas.
O Lighning oferece uma ferramenta para gerenciamento de logs que permite o registro de métricas e hiperparâmetros do modelo. Por exemplo, métricas de treinamento e validação podem ser registradas através do método self.log
do modelo. O código a seguir mostra um exemplo de registro de métricas.
def training_step(self, batch, batch_idx):
...
self.log('train_loss', loss)
return loss
Os logs são armazenados em arquivos no formato Tensorboard e podem ser visualizados através do comando tensorboard --logdir=logs
. Entretanto, os logs podem ser configurados para serem armazenados em outros formatos, como um simples CSV, MLFlow, Neptune, Comet e WandB.
Como descrito acima, o Pytorch Lightning abstrai a implementação de modelos de aprendizado profundo, além de facilitar a reprodutibilidade, organização e a depuração. Entretanto, a implementação de modelos não segue uma receita de bolo e depende do problema a ser resolvido. Desta forma, além de sugerir a implementação dos nossos modelos usando Pytorch Lightning, sugiro algumas diretrizes para a implementação e uso dos modelos:
train_loss
e val_loss
, dentro da função training_step
e validation_step
do modelo, respectivamente. Estas perdas devem ser registradas por meio do método self.log
do modelo, por época (usando o parâmetro on_epoch=True
). Abaixo a linha de código para registro da perda de treino. Realizando o registro desta forma permite facilmente visualizar as curvas de treinamento e validação no Tensorboard, ou externamente.def training_step(self, batch, batch_idx):
...
self.log('train_loss', loss, on_epoch=True, on_step=False, prog_bar=True, logger=True)
return loss
O método forward
do modelo deve retornar as predições do modelo ou os vetores latentens, dependendo do modelo. Por exemplo, no caso do VAE, o método forward
deve retornar os vetores latentes, enquanto que, no caso do MLP, o método forward
deve retornar as predições do modelo. Isso facilita a composição de modelos, como no caso do VAE-GAN, onde o VAE é usado para gerar os vetores latentes que são usados como entrada do GAN.
Sempre que possível optar por criar modelos modulares, onde cada componente do modelo é definido em uma classe separada e a composiçao do modelo é feita por meio de uma classe que herda de pl.LightningModule
, através do método __init__
. Por exemplo, no caso do VAE, podemos definir uma classe Encoder
e uma classe Decoder
que herdam de pl.LightningModule
e, em seguida, definir uma classe VAE
que herda de pl.LightningModule
e que recebe os objetos Encoder
e Decoder
como parâmetros do construtor. Isso facilita a composição de modelos e a reutilização de componentes.
Sempre que possível, criar um builder
que constrói o modelo. O builder
é uma função que recebe os hiperparâmetros do modelo e retorna uma instância do modelo. Por exemplo, no caso do VAE, podemos definir uma função build_vae
que recebe os hiperparâmetros do modelo e retorna uma instância do modelo. Isso facilita a criação de modelos e a definição de modelos conhecidos na literatura. Além disso, fica fácil para embutir em arquivos de configuração, como YAML, JSON, TOML, etc, para realizar hyperparameter tuning e criar um Model Zoo.
Por padrão o Pytorch Lightning não salva os pesos do modelo. Entretanto, podemos usar o callback ModelCheckpoint
para salvar os pesos do modelo. O callback ModelCheckpoint
salva os pesos do modelo sempre que a métrica de validação (val_loss
) melhora. O código a seguir mostra um exemplo de uso do callback ModelCheckpoint
.
from lightning.pytorch.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(
mode='min',
monitor='val_loss',
dirpath='checkpoints/',
filename='model-{epoch:02d}-{val_loss:.2f}',
)
trainer = Trainer(
callbacks=[checkpoint_callback],
...
)
Collection of models
Features to be Implemented