jl749 / the-clean-transformer

A transformer implemented cleanly with pytorch-lightning and wandb

torch_lightning #9

Open jl749 opened 2 years ago

jl749 commented 2 years ago
import torch
from pytorch_lightning import LightningModule

class Transformer(LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()  # store the constructor arguments in self.hparams

        # any nn.Module assigned as an attribute is auto-registered as a submodule
        self.hello_there = torch.nn.Module()

model = Transformer()
model.parameters()  # returns the parameters of every registered nn.Module submodule
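
A minimal sketch (the class and attribute names are illustrative, not from the repo) of what the two lines above buy you: save_hyperparameters() exposes the constructor arguments through self.hparams, and any nn.Module attribute is auto-registered so its parameters show up in model.parameters().

import torch
from torch import nn
from pytorch_lightning import LightningModule

class TinyTransformer(LightningModule):  # hypothetical name, for illustration only
    def __init__(self, hidden_size: int = 32, max_length: int = 10):
        super().__init__()
        self.save_hyperparameters()  # hidden_size / max_length now live in self.hparams
        self.token_embeddings = nn.Embedding(100, hidden_size)  # registered submodule
        self.head = nn.Linear(hidden_size, 100)                 # registered submodule

model = TinyTransformer()
print(model.hparams.hidden_size)                    # 32, captured by save_hyperparameters()
print(sum(p.numel() for p in model.parameters()))   # counts embedding + head parameters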
jl749 commented 2 years ago

LightningModule example

import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 10)

    def forward(self, x):  # Use for inference only
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):  # the complete training loop
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):  # define optimizers and LR schedulers
        return torch.optim.Adam(self.parameters(), lr=0.02)

https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
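
As a usage note, here is a minimal sketch (the dummy data and Trainer flags are assumptions, not part of the docs excerpt) of how a Trainer drives the LitModel defined above:

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

# dummy 28x28 inputs and labels, just to show the Trainer wiring
x = torch.randn(64, 1, 28, 28)
y = torch.randint(0, 10, (64,))
train_loader = DataLoader(TensorDataset(x, y), batch_size=16)

model = LitModel()
trainer = pl.Trainer(max_epochs=1, logger=False)
trainer.fit(model, train_loader)  # runs training_step + configure_optimizers under the hood

preds = model(x[:4])  # forward() is used for inference only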

jl749 commented 2 years ago

Lightning on_train_* hook methods

    # hooks defined inside the LightningModule (requires: import torch, from tqdm import tqdm)
    def on_train_start(self):
        # many deep transformer models are initialised with so-called "Xavier initialisation"
        # refer to: https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
        for param in tqdm(self.parameters(), desc="initialising weights..."):
            if param.dim() > 1:  # skip 1-D tensors such as biases and LayerNorm weights
                torch.nn.init.xavier_uniform_(param)

    def on_train_batch_end(self, outputs: dict, *args, **kwargs):
        # outputs is whatever training_step returned; log the loss to the attached logger (e.g. wandb)
        self.log("Train/Loss", outputs['loss'])