lightning_template.models.base
==============================

.. py:module:: lightning_template.models.base


Classes
-------

.. autoapisummary::

   lightning_template.models.base.LightningModule


Module Contents
---------------

.. py:class:: LightningModule(model: Optional[torch.nn.Module] = None, ckpt_path: Optional[Union[str, List[str]]] = None, finetune_cfg: Optional[Union[str, List[str], Mapping]] = None, evaluator_cfg: Mapping = None, evaluator_as_submodule: bool = True, loss_weights=None, predict_tasks: Optional[List[str]] = None, predict_path: str = 'prediction', *args, **kwargs)

   Bases: :py:obj:`lightning_template.utils.mixin.SplitNameMixin`, :py:obj:`lightning.pytorch.LightningModule`


   Base class for all neural network modules.

   Your models should also subclass this class.

   Modules can also contain other Modules, allowing them to be nested in
   a tree structure. You can assign the submodules as regular attributes::

       import torch.nn as nn
       import torch.nn.functional as F

       class Model(nn.Module):
           def __init__(self) -> None:
               super().__init__()
               self.conv1 = nn.Conv2d(1, 20, 5)
               self.conv2 = nn.Conv2d(20, 20, 5)

           def forward(self, x):
               x = F.relu(self.conv1(x))
               return F.relu(self.conv2(x))

   Submodules assigned in this way will be registered, and will also have their
   parameters converted when you call :meth:`to`, etc.

   .. note::
       As per the example above, an ``__init__()`` call to the parent class
       must be made before assignment on the child.

   :ivar training: Boolean that represents whether this module is in training or
                   evaluation mode.
   :vartype training: bool
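
   The snippet below is a minimal sketch of the behavior described above. It assumes only ``torch`` is
   installed, and the ``Encoder`` class and its layer sizes are hypothetical. It checks that attribute-assigned
   submodules are registered together with their parameters, and that ``eval()``/``train()`` flip the
   ``training`` flag on the whole module tree.

   .. code-block:: python

      import torch.nn as nn


      class Encoder(nn.Module):
          """Hypothetical two-layer module used only for illustration."""

          def __init__(self) -> None:
              super().__init__()  # must run before any submodule is assigned
              self.fc1 = nn.Linear(4, 8)
              self.fc2 = nn.Linear(8, 2)

          def forward(self, x):
              return self.fc2(self.fc1(x).relu())


      model = Encoder()

      # Attribute assignment registers the submodules, in assignment order.
      child_names = [name for name, _ in model.named_children()]  # ['fc1', 'fc2']
      # The parent sees their parameters: one weight and one bias per layer.
      num_tensors = len(list(model.parameters()))  # 4

      # eval() and train() set ``training`` recursively on every submodule.
      model.eval()
      all_eval = not model.training and not model.fc1.training and not model.fc2.training
      model.train()
      all_train = model.training and model.fc1.training and model.fc2.training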


   .. py:attribute:: model
      :value: None



   .. py:attribute:: ckpt_path
      :value: None



   .. py:attribute:: evaluators


   .. py:attribute:: loss_weights
      :value: None



   .. py:attribute:: evaluator_cfg


   .. py:attribute:: evaluate_as_submodule
      :value: True



   .. py:attribute:: finetune_cfg
      :value: None



   .. py:attribute:: predict_tasks
      :value: None



   .. py:attribute:: predict_path
      :value: 'prediction'



   .. py:attribute:: lr
      :value: None



   .. py:attribute:: automatic_lr_schedule
      :value: True



   .. py:attribute:: manual_step_scedulers
      :value: []



   .. py:attribute:: model_not_configured
      :value: True



   .. py:method:: build_model()


   .. py:method:: configure_model()

      Hook to create modules in a strategy- and precision-aware context.

      This is particularly useful when using sharded strategies (FSDP and DeepSpeed), where we'd like to shard
      the model instantly to save memory and initialization time.
      For non-sharded strategies, you can choose to override this hook or to initialize your model under the
      :meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module` context manager.

      This hook is called during each of the fit/val/test/predict stages in the same process, so ensure that the
      implementation of this hook is **idempotent**, i.e., after the first time the hook is called, subsequent calls
      to it should be a no-op.
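
      One common way to make this hook idempotent is a "configure once" guard flag. The plain-Python sketch
      below illustrates the pattern without importing Lightning; ``GuardedModule`` and ``build_count`` are
      hypothetical, and whether ``lightning_template`` guards its hook exactly this way is an assumption. The
      sketch only reuses the ``model_not_configured`` attribute and ``build_model`` method names listed on this page.

      .. code-block:: python

         class GuardedModule:
             """Hypothetical stand-in for a LightningModule (no Lightning import needed)."""

             def __init__(self) -> None:
                 self.model = None
                 self.model_not_configured = True  # mirrors the attribute documented on this page
                 self.build_count = 0  # hypothetical counter, only to observe the behavior

             def build_model(self):
                 # Placeholder for expensive module construction.
                 self.build_count += 1
                 return object()

             def configure_model(self) -> None:
                 # Only the first call builds the model; every later call is a no-op.
                 if not self.model_not_configured:
                     return
                 self.model = self.build_model()
                 self.model_not_configured = False


         m = GuardedModule()
         m.configure_model()  # e.g. during fit
         first = m.model
         m.configure_model()  # e.g. a later validate/test/predict call in the same process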




   .. py:method:: recursive_parse_modules(module)
      :staticmethod:



   .. py:method:: _build_evaluator(split)


   .. py:method:: setup(stage=None)

      Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you
      need to build models dynamically or adjust something about them. This hook is called on every process when
      using DDP.

      :param stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``

      Example::

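          class LitModel(...):
              def __init__(self):
                  super().__init__()
                  self.l1 = None

              def prepare_data(self):
                  download_data()
                  tokenize()

                  # don't do this: prepare_data runs on a single process, so state
                  # assigned here is not available on the other processes
                  self.something = ...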

              def setup(self, stage):
                  data = load_data(...)
                  self.l1 = nn.Linear(28, data.num_classes)




   .. py:method:: on_fit_start()

      Called at the very beginning of fit.

      If on DDP, it is called on every process.




   .. py:method:: optimizer_step(*args, **kwargs) -> None

      Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
      the optimizer.

      By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example.
      This method (and ``zero_grad()``) won't be called during the accumulation phase when the ``Trainer`` is
      configured with an ``accumulate_grad_batches`` value other than 1. Overriding this hook has no benefit with
      manual optimization.

      :param epoch: Current epoch
      :param batch_idx: Index of current batch
      :param optimizer: A PyTorch optimizer
      :param optimizer_closure: The optimizer closure. This closure must be executed as it includes the
                                calls to ``training_step()``, ``optimizer.zero_grad()``, and ``backward()``.

      Examples::

          def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
              # Add your custom logic to run directly before `optimizer.step()`

              optimizer.step(closure=optimizer_closure)

              # Add your custom logic to run directly after `optimizer.step()`
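
      The accumulation behavior can be pictured with the plain-Python sketch below. It is a simplified model,
      not Lightning's implementation: the optimizer steps once every ``accumulate_grad_batches`` batches and,
      as we understand current Lightning releases, also on the last batch of an epoch so that no accumulated
      gradients are left over. ``step_batches`` is a hypothetical helper; verify the exact rule against your
      Lightning version.

      .. code-block:: python

         from typing import List


         def step_batches(num_batches: int, accumulate_grad_batches: int) -> List[int]:
             """Return the batch indices after which the optimizer would step (simplified model)."""
             steps = []
             for batch_idx in range(num_batches):
                 accumulated = (batch_idx + 1) % accumulate_grad_batches == 0
                 is_last = batch_idx == num_batches - 1
                 if accumulated or is_last:
                     steps.append(batch_idx)
             return steps


         print(step_batches(10, 4))  # [3, 7, 9]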




   .. py:method:: flatten_dict(log_dict, prefix, sep='/')
      :staticmethod:
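
      Only the signature of this helper is documented here. Judging from its name and parameters, it likely
      turns a nested dictionary of logged values into flat keys joined by ``sep`` under ``prefix`` (for example
      ``val/loss/cls``), a shape that suits ``self.log_dict``. The sketch below is a hypothetical
      re-implementation under that assumption, not the module's actual code.

      .. code-block:: python

         from typing import Any, Dict, Mapping


         def flatten_dict(log_dict: Mapping[str, Any], prefix: str, sep: str = "/") -> Dict[str, Any]:
             """Hypothetical sketch: join nested keys with ``sep`` under ``prefix``."""
             flat: Dict[str, Any] = {}
             for key, value in log_dict.items():
                 name = f"{prefix}{sep}{key}" if prefix else str(key)
                 if isinstance(value, Mapping):
                     # Recurse so deeper levels inherit the accumulated prefix.
                     flat.update(flatten_dict(value, name, sep))
                 else:
                     flat[name] = value
             return flat


         metrics = {"acc": 0.9, "loss": {"cls": 0.3, "reg": 0.1}}
         print(flatten_dict(metrics, "val"))
         # {'val/acc': 0.9, 'val/loss/cls': 0.3, 'val/loss/reg': 0.1}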



   .. py:method:: forward(batch, *args, **kwargs)

      Same as :meth:`torch.nn.Module.forward`.

      :param \*args: Whatever you decide to pass into the forward method.
      :param \*\*kwargs: Keyword arguments are also possible.

      :returns: Your model's output



   .. py:method:: _loss_step(*args, output, **kwargs)


   .. py:method:: loss_step(*args, use_loss_weight=True, **kwargs)


   .. py:method:: update_evaluator(evaluator, *args, metrics, **kwargs)


   .. py:method:: _metric_step(*args, output, **kwargs)


   .. py:method:: metric_step(*args, dataloader_idx=None, split, **kwargs)


   .. py:method:: _compute_evaluator(evaluator, *args, **kwargs)


   .. py:method:: compute_evaluator(evaluator, dataloader_idx=None, *args, **kwargs)


   .. py:method:: on_metric_epoch_end(*args, split, **kwargs)


   .. py:method:: forward_step(*args, split, **kwargs)


   .. py:method:: on_forward_epoch_end(*args, split, **kwargs)


   .. py:method:: training_step(batch, batch_idx, dataloader_idx=None, *args, **kwargs)

      Here you compute and return the training loss and some additional metrics, for example for the progress
      bar or logger.

      :param batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
      :param batch_idx: The index of this batch.
      :param dataloader_idx: The index of the dataloader that produced this batch.
                             (only if multiple dataloaders used)

      :returns:

                - :class:`~torch.Tensor` - The loss tensor
                - ``dict`` - A dictionary which can include any keys, but must include the key ``'loss'`` in the case of
                  automatic optimization.
                - ``None`` - In automatic optimization, this will skip to the next batch (but is not supported for
                  multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning
                  the loss is not required.

      In this step, you'd normally do the forward pass and calculate the loss for a batch.
      You can also do fancier things, such as multiple forward passes or something model-specific.

      Example::

          def training_step(self, batch, batch_idx):
              x, y, z = batch
              out = self.encoder(x)
              loss = self.loss(out, x)
              return loss

      To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:

      .. code-block:: python

          def __init__(self):
              super().__init__()
              self.automatic_optimization = False


          # Multiple optimizers (e.g.: GANs)
          def training_step(self, batch, batch_idx):
              opt1, opt2 = self.optimizers()

              # do training_step with encoder
              ...
              opt1.step()
              # do training_step with decoder
              ...
              opt2.step()

      .. note::

         When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
         normalized by ``accumulate_grad_batches`` internally.
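
      As a worked illustration of why this normalization is applied (a toy scalar loss, not Lightning code):
      gradients from successive ``backward()`` calls are summed, so scaling each micro-batch loss by
      ``1 / accumulate_grad_batches`` makes the accumulated gradient equal to the average gradient over the
      micro-batches. The loss ``L_i(w) = w * x_i`` and its values are arbitrary.

      .. code-block:: python

         # Toy per-micro-batch loss L_i(w) = w * x_i, so dL_i/dw = x_i.
         xs = [2.0, 4.0, 6.0, 8.0]
         accumulate_grad_batches = len(xs)

         # Summing the gradients of the scaled losses ...
         accumulated_grad = sum(x / accumulate_grad_batches for x in xs)
         # ... gives the mean of the unscaled per-micro-batch gradients.
         mean_grad = sum(xs) / len(xs)
         print(accumulated_grad, mean_grad)  # 5.0 5.0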



   .. py:method:: on_train_epoch_end(*args, **kwargs)

      Called in the training loop at the very end of the epoch.

      To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the
      :class:`~lightning.pytorch.LightningModule` and access them in this hook:

      .. code-block:: python

          import lightning as L
          import torch


          class MyLightningModule(L.LightningModule):
              def __init__(self):
                  super().__init__()
                  self.training_step_outputs = []

              def training_step(self, batch, batch_idx):
                  loss = ...
                  self.training_step_outputs.append(loss)
                  return loss

              def on_train_epoch_end(self):
                  # do something with all training_step outputs, for example:
                  epoch_mean = torch.stack(self.training_step_outputs).mean()
                  self.log("training_epoch_mean", epoch_mean)
                  # free up the memory
                  self.training_step_outputs.clear()




   .. py:method:: validation_step(batch, batch_idx, dataloader_idx=None, *args, **kwargs)

      Operates on a single batch of data from the validation set. In this step you might generate examples or
      calculate anything of interest, such as accuracy.

      :param batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
      :param batch_idx: The index of this batch.
      :param dataloader_idx: The index of the dataloader that produced this batch.
                             (only if multiple dataloaders used)

      :returns:

                - :class:`~torch.Tensor` - The loss tensor
                - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
                - ``None`` - Skip to the next batch.

      .. code-block:: python

          # if you have one val dataloader:
          def validation_step(self, batch, batch_idx): ...


          # if you have multiple val dataloaders:
          def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

      Examples::

          # CASE 1: A single validation dataset
          def validation_step(self, batch, batch_idx):
              x, y = batch

              # implement your own
              out = self(x)
              loss = self.loss(out, y)

              # log 6 example images
              # or generated text... or whatever
              sample_imgs = x[:6]
              grid = torchvision.utils.make_grid(sample_imgs)
              self.logger.experiment.add_image('example_images', grid, 0)

              # calculate acc
              labels_hat = torch.argmax(out, dim=1)
              val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

              # log the outputs!
              self.log_dict({'val_loss': loss, 'val_acc': val_acc})

      If you pass in multiple val dataloaders, :meth:`validation_step` will have an additional argument. We recommend
      setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

      .. code-block:: python

          # CASE 2: multiple validation dataloaders
          def validation_step(self, batch, batch_idx, dataloader_idx=0):
              # dataloader_idx tells you which dataset this is.
              ...

      .. note:: If you don't need to validate, you don't need to implement this method.

      .. note::

         When the :meth:`validation_step` is called, the model has been put in eval mode
         and PyTorch gradients have been disabled. At the end of validation,
         the model goes back to training mode and gradients are enabled.
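
      For intuition, the plain-PyTorch sketch below approximates that context by hand; the ``nn.Sequential``
      model and tensor shapes are arbitrary. It is not Lightning's code: depending on
      ``Trainer(inference_mode=...)``, Lightning may use ``torch.inference_mode()`` rather than
      ``torch.no_grad()``.

      .. code-block:: python

         import torch
         import torch.nn as nn

         # Arbitrary toy model; Dropout behaves differently in train and eval mode.
         model = nn.Sequential(nn.Linear(3, 3), nn.Dropout(p=0.5))
         batch = torch.randn(4, 3)

         # Before validation: eval mode and no autograd graph.
         model.eval()
         with torch.no_grad():
             out = model(batch)  # Dropout is a no-op in eval mode
             grad_enabled_inside = torch.is_grad_enabled()

         # After validation: back to training mode; gradient tracking resumes
         # once the no_grad block has exited.
         model.train()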



   .. py:method:: on_validation_epoch_end(*args, **kwargs)

      Called in the validation loop at the very end of the epoch.



   .. py:method:: test_step(batch, batch_idx, dataloader_idx=None, *args, **kwargs)

      Operates on a single batch of data from the test set. In this step you'd normally generate examples or
      calculate anything of interest such as accuracy.

      :param batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
      :param batch_idx: The index of this batch.
      :param dataloader_idx: The index of the dataloader that produced this batch.
                             (only if multiple dataloaders used)

      :returns:

                - :class:`~torch.Tensor` - The loss tensor
                - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
                - ``None`` - Skip to the next batch.

      .. code-block:: python

          # if you have one test dataloader:
          def test_step(self, batch, batch_idx): ...


          # if you have multiple test dataloaders:
          def test_step(self, batch, batch_idx, dataloader_idx=0): ...

      Examples::

          # CASE 1: A single test dataset
          def test_step(self, batch, batch_idx):
              x, y = batch

              # implement your own
              out = self(x)
              loss = self.loss(out, y)

              # log 6 example images
              # or generated text... or whatever
              sample_imgs = x[:6]
              grid = torchvision.utils.make_grid(sample_imgs)
              self.logger.experiment.add_image('example_images', grid, 0)

              # calculate acc
              labels_hat = torch.argmax(out, dim=1)
              test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

              # log the outputs!
              self.log_dict({'test_loss': loss, 'test_acc': test_acc})

      If you pass in multiple test dataloaders, :meth:`test_step` will have an additional argument. We recommend
      setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

      .. code-block:: python

          # CASE 2: multiple test dataloaders
          def test_step(self, batch, batch_idx, dataloader_idx=0):
              # dataloader_idx tells you which dataset this is.
              ...

      .. note:: If you don't need to test, you don't need to implement this method.

      .. note::

         When the :meth:`test_step` is called, the model has been put in eval mode and
         PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
         to training mode and gradients are enabled.



   .. py:method:: on_test_epoch_end(*args, **kwargs)

      Called in the test loop at the very end of the epoch.



   .. py:method:: rm_and_create(path)
      :staticmethod:
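
      Only the signature of this helper is documented here. Its name suggests that it deletes ``path`` if it
      exists and then recreates it as an empty directory; the sketch below is a hypothetical re-implementation
      under that assumption, not the module's actual code. Note that such a helper is destructive: anything
      already stored under ``path`` is lost. The ``"prediction"`` directory name in the usage part is only for
      illustration.

      .. code-block:: python

         import os
         import shutil
         import tempfile


         def rm_and_create(path: str) -> None:
             """Hypothetical sketch: delete ``path`` if present, then recreate it empty."""
             if os.path.isdir(path):
                 shutil.rmtree(path)  # removes the directory and everything inside it
             elif os.path.exists(path):
                 os.remove(path)  # a regular file occupies the path
             os.makedirs(path)


         # Usage: a directory holding a stale file is reset to an empty directory.
         root = tempfile.mkdtemp()
         target = os.path.join(root, "prediction")
         os.makedirs(target)
         with open(os.path.join(target, "stale.txt"), "w") as f:
             f.write("old output")

         rm_and_create(target)
         remaining = os.listdir(target)  # []
         shutil.rmtree(root)  # clean up the temporary directory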



   .. py:method:: on_predict_start() -> None

      Called at the beginning of predicting.



   .. py:method:: predict_step(*args, **kwargs)

      Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it calls
      :meth:`~lightning.pytorch.core.LightningModule.forward`. Override to add any processing logic.

      The :meth:`~lightning.pytorch.core.LightningModule.predict_step` is used
      to scale inference on multiple devices.

      To prevent an OOM error, you can use the :class:`~lightning.pytorch.callbacks.BasePredictionWriter`
      callback to write the predictions to disk or a database after each batch or at the end of the epoch.

      Use :class:`~lightning.pytorch.callbacks.BasePredictionWriter` with spawn-based strategies, such as
      ``Trainer(strategy="ddp_spawn")`` or training on 8 TPU cores with ``Trainer(accelerator="tpu", devices=8)``,
      because in those cases predictions won't be returned.
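
      A minimal sketch of such a writer follows. The class name ``TorchSaveWriter``, the file-naming scheme, and
      the use of ``torch.save`` are illustrative choices, not part of ``lightning_template``. It uses
      ``write_interval="batch"`` and overrides ``write_on_batch_end`` of
      :class:`~lightning.pytorch.callbacks.BasePredictionWriter`.

      .. code-block:: python

         import os

         import torch
         from lightning.pytorch.callbacks import BasePredictionWriter


         class TorchSaveWriter(BasePredictionWriter):
             """Hypothetical writer that saves each batch's predictions under ``output_dir``."""

             def __init__(self, output_dir: str) -> None:
                 super().__init__(write_interval="batch")
                 self.output_dir = output_dir

             def write_on_batch_end(
                 self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
             ):
                 os.makedirs(self.output_dir, exist_ok=True)
                 # Include the global rank so that DDP processes do not overwrite each other's files.
                 name = f"rank{trainer.global_rank}_dl{dataloader_idx}_batch{batch_idx}.pt"
                 torch.save(prediction, os.path.join(self.output_dir, name))

      You would then pass it with ``Trainer(callbacks=[TorchSaveWriter("predictions_out")])`` and call
      ``trainer.predict(model, dm, return_predictions=False)`` so the predictions are not also kept in memory.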

      :param batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
      :param batch_idx: The index of this batch.
      :param dataloader_idx: The index of the dataloader that produced this batch.
                             (only if multiple dataloaders used)

      :returns: Predicted output (optional).

      Example::

          class MyModel(LightningModule):

              def predict_step(self, batch, batch_idx, dataloader_idx=0):
                  return self(batch)

          dm = ...
          model = MyModel()
          trainer = Trainer(accelerator="gpu", devices=2)
          predictions = trainer.predict(model, dm)




   .. py:method:: on_predict_end() -> None

      Called at the end of predicting.



