Open jl749 opened 2 years ago
import pytorch_lightning as pl
class LitModel(pl.LightningModule):
    """Minimal LightningModule: one linear layer followed by a ReLU.

    Each sample is flattened and mapped from ``28 * 28`` features to 10
    outputs. Training minimises cross-entropy with Adam.
    """

    def __init__(self):
        super().__init__()
        # The only trainable layer: 28 * 28 flattened inputs -> 10 outputs.
        self.l1 = nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Inference path: flatten per sample, apply ``l1``, then ReLU."""
        # Keep the leading (batch) dimension and collapse everything else.
        flattened = x.view(x.size(0), -1)
        projected = self.l1(flattened)
        return torch.relu(projected)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one ``(x, y)`` batch.

        ``batch_idx`` is accepted for the hook signature but not used.
        """
        inputs, targets = batch
        predictions = self(inputs)  # dispatches to ``forward``
        return F.cross_entropy(predictions, targets)

    def configure_optimizers(self):
        """Build the optimizer: Adam over all parameters with ``lr=0.02``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
        return optimizer
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
on_train
# on_train* hook methods
def on_train_start(self):
    """Apply Xavier-uniform initialisation to every multi-dimensional parameter.

    Iterates ``self.parameters()`` behind a tqdm progress bar and
    re-initialises, in place, each parameter that has more than one dimension.
    """
    # Many deep transformer models are initialised with so-called
    # "Xavier initialisation".
    # Refer to: https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
    progress = tqdm(self.parameters(), desc="initialising weights...")
    for weight in progress:
        # Parameters with at most one dimension are left untouched.
        if weight.dim() <= 1:
            continue
        torch.nn.init.xavier_uniform_(weight)
def on_train_batch_end(self, outputs: dict, *args, **kwargs):
    """Log the training loss of the batch that just finished.

    Args:
        outputs: Mapping handed to the hook; the loss is read from its
            ``'loss'`` key.
        *args: Remaining positional hook arguments; accepted and ignored.
        **kwargs: Remaining keyword hook arguments; accepted and ignored.

    Nothing is logged when ``outputs`` is empty or carries no loss.
    """
    # NOTE(review): Lightning appears to pass an empty mapping when
    # ``training_step`` returned ``None`` (batch skipped) - confirm against
    # the installed version. Indexing ``outputs['loss']`` directly would
    # raise KeyError in that case.
    loss = outputs.get("loss") if outputs else None
    if loss is None:
        return
    self.log("Train/Loss", loss)