lightning_template.datasets.base
================================

.. py:module:: lightning_template.datasets.base


Classes
-------

.. autoapisummary::

   lightning_template.datasets.base.LightningDataModule


Module Contents
---------------

.. py:class:: LightningDataModule(dataset_cfg: dict = None, dataloader_cfg: dict = None)

   Bases: :py:obj:`lightning_template.utils.mixin.SplitNameMixin`, :py:obj:`lightning.pytorch.core.datamodule.LightningDataModule`


   A DataModule standardizes the training, validation, and test splits, data preparation, and transforms. The main
   advantage is consistent data splits, data preparation, and transforms across models.

   Example::

       import lightning as L
       import torch
       import torch.utils.data as data
       from lightning.pytorch.demos.boring_classes import RandomDataset

       class MyDataModule(L.LightningDataModule):
           def prepare_data(self):
               # download, IO, etc. Useful with shared filesystems
               # only called on 1 GPU/TPU in distributed
               ...

           def setup(self, stage):
               # make assignments here (val/train/test split)
               # called on every process in DDP
               dataset = RandomDataset(1, 100)
               self.train, self.val, self.test = data.random_split(
                   dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
               )

           def train_dataloader(self):
               return data.DataLoader(self.train)

           def val_dataloader(self):
               return data.DataLoader(self.val)

           def test_dataloader(self):
               return data.DataLoader(self.test)

           def on_exception(self, exception):
               # clean up state after the trainer faced an exception
               ...

           def teardown(self, stage):
               # clean up state after the trainer stops, delete files...
               # called on every process in DDP
               ...



   .. py:attribute:: datasets


   .. py:attribute:: dataset
      :value: None



   .. py:attribute:: num_folds
      :value: None



   .. py:attribute:: folds


   .. py:attribute:: splits
      :value: []



   .. py:attribute:: batch_size
      :value: None



   .. py:attribute:: dataset_cfg


   .. py:attribute:: dataloader_cfg


   .. py:method:: build_dataset(split)


   .. py:method:: build_collate_fn(collate_fn_cfg, dataset)


   .. py:method:: build_sampler(dataloader_cfg, dataset, split)


   .. py:method:: build_batch_sampler(batch_sampler_cfg, dataset, *args)


   .. py:method:: handle_dataloader_config(dataloader_cfg, dataset, split, *arg, **kwargs)


   .. py:method:: _build_dataloader(dataset, dataloader_cfg, split)


   .. py:method:: build_dataloader(split)


   .. py:method:: setup(stage=None)

      Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you
      need to build models dynamically or adjust something about them. This hook is called on every process when
      using DDP.

      :param stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``

      Example::

          class LitModel(...):
              def __init__(self):
                  self.l1 = None

              def prepare_data(self):
                  download_data()
                  tokenize()

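                  # don't do this: prepare_data runs on a single process,
                  # so state assigned here is missing on the others
                  self.something = something_else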

              def setup(self, stage):
                  data = load_data(...)
                  self.l1 = nn.Linear(28, data.num_classes)
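
      For a data module, ``stage`` can be used to build only the splits that a given run needs. The following is a
      minimal sketch under stated assumptions, not this class's implementation: ``StageAwareDataModule`` and its
      attribute names are hypothetical, and the data is synthetic.

      .. code-block:: python

          import lightning as L
          import torch
          from torch.utils.data import TensorDataset, random_split


          class StageAwareDataModule(L.LightningDataModule):
              """Hypothetical module that builds splits according to ``stage``."""

              def setup(self, stage=None):
                  # Synthetic stand-in for a real dataset.
                  full = TensorDataset(torch.arange(100, dtype=torch.float32))
                  if stage in ("fit", "validate", None):
                      # A fixed seed keeps the split identical on every process.
                      self.train_set, self.val_set = random_split(
                          full, [80, 20], generator=torch.Generator().manual_seed(0)
                      )
                  if stage in ("test", "predict", None):
                      self.test_set = full


          dm = StageAwareDataModule()
          dm.setup("fit")  # builds train_set and val_set only

      Calling ``setup`` by hand, as in the last two lines, is only for illustration; during a run the trainer
      invokes the hook itself.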




   .. py:method:: setup_folds(num_folds: int) -> None


   .. py:method:: setup_fold_index(fold_index: int) -> None


   .. py:method:: train_dataloader()

      An iterable or collection of iterables specifying training samples.

      For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

      The dataloader you return will not be reloaded unless you set
      :paramref:`~lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs` to
      a positive integer.
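
      As a brief illustration of that flag, the sketch below asks the trainer to rebuild the dataloaders every
      epoch. The values are arbitrary, and ``logger=False`` and ``enable_checkpointing=False`` are only there to
      keep the example free of side effects.

      .. code-block:: python

          import lightning as L

          # Request fresh dataloaders every epoch, for example when the
          # training set is resampled between epochs.
          trainer = L.Trainer(
              max_epochs=3,
              reload_dataloaders_every_n_epochs=1,
              logger=False,
              enable_checkpointing=False,
          )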

      For data processing use the following pattern:

          - download in :meth:`prepare_data`
          - process and split in :meth:`setup`

      However, the above are only necessary for distributed processing.

      .. warning:: Do not assign state in :meth:`prepare_data`.

      Related hooks and entry points:

      - :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`
      - :meth:`prepare_data`
      - :meth:`setup`

      .. note::

         Lightning tries to add the correct sampler for distributed and arbitrary hardware.
         There is no need to set it yourself.



   .. py:method:: val_dataloader()

      An iterable or collection of iterables specifying validation samples.

      For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

      The dataloader you return will not be reloaded unless you set
      :paramref:`~lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs` to
      a positive integer.

      It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.

      Related hooks and entry points:

      - :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`
      - :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`
      - :meth:`prepare_data`
      - :meth:`setup`

      .. note::

         Lightning tries to add the correct sampler for distributed and arbitrary hardware.
         There is no need to set it yourself.

      .. note::

         If you don't need a validation dataset and a :meth:`validation_step`, you don't need to
         implement this method.
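
      To make "collection of iterables" concrete, the sketch below builds two validation dataloaders that could be
      returned together as a list. When several are returned, Lightning passes a ``dataloader_idx`` argument to
      ``validation_step`` so that the batches can be told apart. The helper ``make_val_loaders`` and both
      datasets are hypothetical and not part of this module.

      .. code-block:: python

          import torch
          from torch.utils.data import DataLoader, TensorDataset


          def make_val_loaders():
              """Return two validation dataloaders over synthetic data."""
              clean = TensorDataset(torch.zeros(8, 3))
              noisy = TensorDataset(torch.ones(4, 3))
              return [
                  DataLoader(clean, batch_size=4),
                  DataLoader(noisy, batch_size=4),
              ]


          loaders = make_val_loaders()
          # Two batches in the first loader, one batch in the second.
          batches_per_loader = [len(loader) for loader in loaders]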



   .. py:method:: test_dataloader()

      An iterable or collection of iterables specifying test samples.

      For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

      For data processing use the following pattern:

          - download in :meth:`prepare_data`
          - process and split in :meth:`setup`

      However, the above are only necessary for distributed processing.

      .. warning:: Do not assign state in :meth:`prepare_data`.


      Related hooks and entry points:

      - :meth:`~lightning.pytorch.trainer.trainer.Trainer.test`
      - :meth:`prepare_data`
      - :meth:`setup`

      .. note::

         Lightning tries to add the correct sampler for distributed and arbitrary hardware.
         There is no need to set it yourself.

      .. note::

         If you don't need a test dataset and a :meth:`test_step`, you don't need to implement
         this method.
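
      The download-then-setup pattern and the warning about state can be sketched as follows. This is an
      illustration under stated assumptions rather than this class's implementation: ``DiskBackedDataModule``,
      the file name ``test.pt``, and the attribute ``test_set`` are hypothetical, and writing a small tensor
      stands in for a real download.

      .. code-block:: python

          import tempfile
          from pathlib import Path

          import lightning as L
          import torch
          from torch.utils.data import DataLoader, TensorDataset


          class DiskBackedDataModule(L.LightningDataModule):
              """Hypothetical module separating disk I/O from state assignment."""

              def __init__(self, root):
                  super().__init__()
                  self.root = Path(root)

              def prepare_data(self):
                  # Write to disk only; assign no attributes here.
                  data = torch.arange(10, dtype=torch.float32)
                  torch.save(data, self.root / "test.pt")

              def setup(self, stage=None):
                  # Load from disk and assign state here.
                  self.test_set = TensorDataset(torch.load(self.root / "test.pt"))

              def test_dataloader(self):
                  return DataLoader(self.test_set, batch_size=5)


          with tempfile.TemporaryDirectory() as tmp:
              dm = DiskBackedDataModule(tmp)
              dm.prepare_data()
              dm.setup("test")
              batches = [batch[0] for batch in dm.test_dataloader()]

      The hooks are called by hand here only to keep the example short; in a real run
      :meth:`~lightning.pytorch.trainer.trainer.Trainer.test` triggers them.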



   .. py:method:: predict_dataloader()

      An iterable or collection of iterables specifying prediction samples.

      For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

      It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.

      Related hooks and entry points:

      - :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`
      - :meth:`prepare_data`
      - :meth:`setup`

      .. note::

         Lightning tries to add the correct sampler for distributed and arbitrary hardware.
         There is no need to set it yourself.

      :returns: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples.



