
lightning

HydraDDP

Bases: DDPPlugin

DDP Strategy that supports Hydra run and multirun jobs.

This strategy assumes that a PyTorch Lightning Trainer.fit or Trainer.test call has been configured to execute via Hydra. It requires Hydra to save a config.yaml in the current working directory with the following keys set:

├── Config
│   ├── trainer: A pytorch_lightning.Trainer configuration
│   ├── module: A pytorch_lightning.LightningModule configuration
│   ├── datamodule: [OPTIONAL] A pytorch_lightning.LightningDataModule configuration
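To illustrate the key requirements above, the following sketch checks a plain dictionary, which stands in for the loaded config.yaml, for the required trainer and module entries and the optional datamodule entry. The function check_hydra_ddp_config is hypothetical and is not part of mushin. In practice Hydra/OmegaConf loads the config, so you would not write it as a dict.

```python
from typing import Any, List, Mapping

REQUIRED_KEYS = ("trainer", "module")  # must be present in config.yaml
OPTIONAL_KEYS = ("datamodule",)        # may be omitted


def check_hydra_ddp_config(cfg: Mapping[str, Any]) -> List[str]:
    """Hypothetical helper: return the top-level component keys that are usable.

    Raises KeyError if a required key is missing.
    """
    missing = [key for key in REQUIRED_KEYS if key not in cfg]
    if missing:
        raise KeyError(f"config.yaml is missing required key(s): {missing}")
    # Keep the required keys, plus any optional keys that are present.
    return list(REQUIRED_KEYS) + [key for key in OPTIONAL_KEYS if key in cfg]


# A config without a datamodule is still valid.
cfg = {"trainer": {"max_epochs": 1}, "module": {"_target_": "SimpleLightningModule"}}
print(check_hydra_ddp_config(cfg))  # ['trainer', 'module']
```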

This strategy launches a child subprocess for each additional GPU beyond the first, using the following base command:

python -m mushin.lightning._pl_main -cp -cn config.yaml
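The sketch below shows how a launcher could assemble one such command for each additional GPU. It rests on two assumptions. First, the -cp argument is the directory that holds the saved config.yaml (Hydra's -cp/--config-path flag). Second, each child learns its rank through a LOCAL_RANK environment variable, which mirrors how Lightning's DDP launcher starts its child scripts. The names build_child_commands and config_dir are hypothetical, and the sketch does not launch anything.

```python
import os
import sys
from typing import Dict, List, Tuple


def build_child_commands(
    num_gpus: int, config_dir: str
) -> List[Tuple[List[str], Dict[str, str]]]:
    """Hypothetical helper: return (argv, env) pairs for ranks 1..num_gpus-1.

    Rank 0 is the current process, so only the additional GPUs get a child.
    """
    commands = []
    for local_rank in range(1, num_gpus):
        argv = [
            sys.executable, "-m", "mushin.lightning._pl_main",
            "-cp", config_dir,      # assumed: directory containing config.yaml
            "-cn", "config.yaml",
        ]
        env = os.environ.copy()
        env["LOCAL_RANK"] = str(local_rank)  # assumed rank-passing mechanism
        commands.append((argv, env))
    return commands


for argv, env in build_child_commands(num_gpus=2, config_dir="path/to/run_dir"):
    # A real launcher would call subprocess.Popen(argv, env=env) here.
    print(env["LOCAL_RANK"], " ".join(argv[1:]))
    # 1 -m mushin.lightning._pl_main -cp path/to/run_dir -cn config.yaml
```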

Examples:

First, define a Hydra configuration using hydra-zen:

>>> import pytorch_lightning as pl
>>> from hydra_zen import builds, make_config
>>> from mushin import HydraDDP
>>> from mushin.testing.lightning import SimpleLightningModule
>>> TrainerConfig = builds(
...     pl.Trainer,
...     accelerator="auto",
...     gpus=2,
...     max_epochs=1,
...     fast_dev_run=True,
...     strategy=builds(HydraDDP),
...     populate_full_signature=True,
... )
>>> ModuleConfig = builds(SimpleLightningModule)
>>> Config = make_config(
...     trainer=TrainerConfig,
...     module=ModuleConfig,
... )

Next, define a task function to execute the Hydra job:

>>> from hydra_zen import instantiate
>>> def task_function(cfg):
...     obj = instantiate(cfg)
...     obj.trainer.fit(obj.module)

Launch the Hydra+Lightning DDP job:

>>> from hydra_zen import launch
>>> job = launch(Config, task_function)

HydraDDP also supports a LightningDataModule configuration:

>>> DataModuleConfig = ...  # A LightningDataModule config
>>> Config = make_config(
...     trainer=TrainerConfig,
...     module=ModuleConfig,
...     datamodule=DataModuleConfig,
... )

Next, define a task function to execute the Hydra job:

>>> from hydra_zen import instantiate
>>> def task_function(cfg):
...     obj = instantiate(cfg)
...     obj.trainer.fit(obj.module, datamodule=obj.datamodule)

Launch the Hydra+Lightning DDP job:

>>> from hydra_zen import launch
>>> job = launch(Config, task_function)

teardown

teardown()

Performs additional PyTorch Lightning teardown steps so that Hydra multirun jobs can run.
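In a Hydra multirun, the default launcher runs jobs one after another in the same Python process, so state that one DDP job leaves behind can leak into the next. The sketch below shows one plausible kind of cleanup. It assumes the leftover state consists of DDP-related environment variables. The variable list and the clear_ddp_env name are assumptions made for illustration; the sketch does not describe what HydraDDP.teardown actually does.

```python
import os
from typing import Dict, Iterable

# Assumed list of variables a DDP run may set; not taken from mushin.
DDP_ENV_VARS = ("LOCAL_RANK", "NODE_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")


def clear_ddp_env(names: Iterable[str] = DDP_ENV_VARS) -> Dict[str, str]:
    """Hypothetical helper: remove DDP variables so the next job starts fresh.

    Returns the variables that were removed, mapped to their old values.
    """
    removed = {}
    for name in names:
        value = os.environ.pop(name, None)
        if value is not None:
            removed[name] = value
    return removed


os.environ["MASTER_PORT"] = "12345"  # simulate state left by a finished job
removed = clear_ddp_env()
print(removed["MASTER_PORT"], "MASTER_PORT" in os.environ)  # 12345 False
```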

MetricsCallback

Bases: Callback

Saves validation and test metrics stored in trainer.callback_metrics.

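Parameters:

Name     | Type          | Description                                                                                                                         | Default
-------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------- | ------------
save_dir | str, optional | The directory in which metrics are saved.                                                                                           | '.'
filename | str, optional | The base filename used to store metrics. During FITTING the filename is prefixed with "fit_", and during TESTING it is prefixed with "test_". | 'metrics.pt'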
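The naming rule for save_dir and filename can be sketched as follows. This minimal stand-in is not MetricsCallback's implementation. It writes a plain dict of floats with json, whereas a .pt file would normally be written with torch.save. The "fit"/"test" stage strings and the metrics_path and save_metrics names are assumptions.

```python
import json
import os
import tempfile
from typing import Dict

# Assumed mapping from trainer stage to filename prefix.
STAGE_PREFIXES = {"fit": "fit_", "test": "test_"}


def metrics_path(save_dir: str, filename: str, stage: str) -> str:
    """Hypothetical helper: build the output path for the given stage."""
    return os.path.join(save_dir, STAGE_PREFIXES[stage] + filename)


def save_metrics(metrics: Dict[str, float], save_dir: str, filename: str, stage: str) -> str:
    """Hypothetical helper: write metrics as JSON (a stand-in for torch.save)."""
    path = metrics_path(save_dir, filename, stage)
    with open(path, "w") as f:
        json.dump(metrics, f)
    return path


with tempfile.TemporaryDirectory() as tmp:
    path = save_metrics({"val_loss": 0.25}, tmp, "metrics.json", stage="fit")
    print(os.path.basename(path))  # fit_metrics.json
    with open(path) as f:
        print(json.load(f))        # {'val_loss': 0.25}
```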
Notes

No metrics will be saved during FITTING if no validation metrics are calculated. This is a limitation of PyTorch Lightning. Future versions will save the training step metrics when no validation metrics are calculated.

Examples:

>>> from pytorch_lightning import Trainer
>>> from mushin import MetricsCallback
>>> metrics_callback = MetricsCallback()
>>> trainer = Trainer(callbacks=[metrics_callback])