zettelkasten

Search IconIcon to open search
Dark ModeDark Mode

Template for PyTorch Lightning

Date: 26 Dec 2021

#post

This post originally appeared on Blog 2.0


A useful template for a PyTorch Lightning module:

class LitModel(pl.LightningModule):
    """Skeleton of a PyTorch Lightning module; fill in each hook as needed.

    Every ``*args, **kwargs`` signature stands in for the real arguments of
    that hook / your model, and every ``...`` body stands in for your code.
    Delete the hooks you do not need.

    NOTE(review): the hook names and the ``optimizer_idx`` convention follow
    the Lightning API this template was written against (Dec 2021) -- confirm
    them against the installed Lightning version.
    """

    def __init__(self, *args, **kwargs):
        # The parent initialiser has to run first so the underlying module
        # state exists before anything else touches `self`.
        super().__init__()
        # call this to save args to the checkpoint
        self.save_hyperparameters()

    def forward(self, *args, **kwargs):
        """The forward step goes here."""
        ...

    def training_step(self, *args, **kwargs):
        """What would have been the body of the training loop.

        pl automatically passes in `batch` and `batch_idx`. Return the loss!
        """
        # in a GAN, pl will call this function with an `optimizer_idx` parameter
        # if optimizer_idx == 0:
        #     do training_step with encoder
        # if optimizer_idx == 1:
        #     do training_step with decoder

        # use self.log() to log stuff to tensorboard
        # self.log("my_metric", x, on_step=True, on_epoch=True, prog_bar=True, logger=True)  # a scalar
        # self.log("performance", {"acc": acc, "recall": recall}, on_step=True, on_epoch=True, prog_bar=True, logger=True)  # or a dict

        # return the loss directly, or return a dict that includes 'loss' as a key:
        # return {"loss": loss, "hiddens": None}
        ...

    def training_step_end(self, *args, **kwargs):
        """Runs after each training step."""
        ...

    def training_epoch_end(self, *args, **kwargs):
        """Runs after each epoch."""
        ...

    def validation_step(self, *args, **kwargs):
        """Validation counterpart of `training_step`."""
        ...

    def validation_step_end(self, *args, **kwargs):
        """Validation counterpart of `training_step_end`."""
        ...

    def validation_epoch_end(self, *args, **kwargs):
        """Validation counterpart of `training_epoch_end`."""
        ...

    def test_step(self, *args, **kwargs):
        """Test counterpart of `training_step`."""
        ...

    def test_step_end(self, *args, **kwargs):
        """Test counterpart of `training_step_end`."""
        ...

    def test_epoch_end(self, *args, **kwargs):
        """Test counterpart of `training_epoch_end`."""
        ...

    def configure_optimizers(self, *args, **kwargs):
        """Define the pytorch optimizers here.

        Return [list of optimizers], [list of LR schedulers].
        """
        ...

    def any_extra_hook(self, *args, **kwargs):
        """Extra stuff."""
        ...

And a matching template for the DataModule:

class DataModule(LightningDataModule):
    """Skeleton of a Lightning data module exposing one loader per stage.

    The loader methods read ``self.train_dataset``, ``self.val_dataset_1``,
    ``self.val_dataset_2``, ``self.test_dataset`` and ``self.predict_dataset``.
    This template does not create those attributes: assign them (for example
    in ``__init__``) before any loader is requested.
    """

    def __init__(self, *args, **kwargs):
        # `*args, **kwargs` is a placeholder for your own constructor arguments.
        super().__init__()
        # TODO: build or store the datasets used by the loaders below.

    # these funcs can also be placed directly inside a LightningModule
    def train_dataloader(self):
        """Return a loader over the training dataset."""
        return DataLoader(self.train_dataset)

    def val_dataloader(self):
        """Return a list with one loader per validation dataset."""
        return [DataLoader(self.val_dataset_1), DataLoader(self.val_dataset_2)]

    def test_dataloader(self):
        """Return a loader over the test dataset."""
        return DataLoader(self.test_dataset)

    def predict_dataloader(self):
        """Return a loader over the prediction dataset."""
        return DataLoader(self.predict_dataset)