Template for PyTorch Lightning
#post
This post originally appeared on Blog 2.0
A useful template for a PyTorch Lightning module:
class LitModel(pl.LightningModule):
    """Template LightningModule: fill in the hook bodies for your model.

    PyTorch Lightning drives the train/val/test loops and calls these
    hooks for you — you never write the loops yourself.
    """

    def __init__(self):
        super().__init__()
        # call this to save args to the checkpoint
        self.save_hyperparameters()

    def forward(self, x):
        # the forward step here
        ...

    def training_step(self, batch, batch_idx):
        # what should have been in the training loop. pl automatically throws in batch and batch_idx. return the loss!
        # in GAN, pl will call this function with an `optimizer_idx` parameter
        # if optimizer_idx == 0:
        #     do training_step with encoder
        # if optimizer_idx == 1:
        #     do training_step with decoder
        # use self.log() to log stuff to tensorboard
        # self.log("my_metric", x, on_step=True, on_epoch=True, prog_bar=True, logger=True)  # a scalar
        # self.log("performance", {"acc": acc, "recall": recall}, on_step=True, on_epoch=True, prog_bar=True, logger=True)  # or a dict
        # return the loss directly, or return a dict that includes 'loss' as the key.
        # return {"loss": loss, "hiddens": None}
        ...

    def training_step_end(self, step_output):
        # runs after each training step
        ...

    def training_epoch_end(self, outputs):
        # runs after each epoch
        ...

    def validation_step(self, batch, batch_idx):
        ...

    def validation_step_end(self, step_output):
        ...

    def validation_epoch_end(self, outputs):
        ...

    def test_step(self, batch, batch_idx):
        ...

    def test_step_end(self, step_output):
        ...

    def test_epoch_end(self, outputs):
        ...

    def configure_optimizers(self):
        # define pytorch optimizers here. return [list of optimizers], [list of LR schedulers]
        ...

    def any_extra_hook(self, *args, **kwargs):
        # extra stuff
        ...
And the corresponding DataModule:
class DataModule(LightningDataModule):
    """Template LightningDataModule bundling the dataloaders for a model."""

    def __init__(self):
        super().__init__()
        # these funcs can also be placed directly inside a LightningModule

    def train_dataloader(self):
        return DataLoader(self.train_dataset)

    def val_dataloader(self):
        # multiple validation sets: return a list of loaders
        return [DataLoader(self.val_dataset_1), DataLoader(self.val_dataset_2)]

    def test_dataloader(self):
        return DataLoader(self.test_dataset)

    def predict_dataloader(self):
        return DataLoader(self.predict_dataset)