lightning
HydraDDP
Bases: DDPPlugin
DDP strategy that supports Hydra run and multirun jobs.
This strategy assumes that a PyTorch Lightning Trainer.fit or Trainer.test call has been configured
to execute via Hydra. It requires that Hydra saves a config.yaml in the current working directory with the following keys set:
Config
├── trainer: A pytorch_lightning.Trainer configuration
├── module: A pytorch_lightning.LightningModule configuration
└── datamodule: [OPTIONAL] A pytorch_lightning.LightningDataModule configuration
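HydraDDP's own config-loading code is not reproduced here. As a minimal sketch of the key requirements listed above, the hypothetical helper check_config below (not part of mushin) validates a plain dict standing in for the parsed config.yaml: it raises if trainer or module is missing and reports whether the optional datamodule key is present.

```python
from typing import Any, Mapping

# Keys that must appear at the top level of config.yaml (per the tree above).
REQUIRED_KEYS = ("trainer", "module")


def check_config(cfg: Mapping[str, Any]) -> bool:
    """Hypothetical helper, not part of mushin.

    Raises KeyError if a required key is missing; returns True when the
    optional 'datamodule' key is also present.
    """
    missing = [key for key in REQUIRED_KEYS if key not in cfg]
    if missing:
        raise KeyError(f"config.yaml is missing required key(s): {missing}")
    return "datamodule" in cfg
```

For example, check_config({"trainer": {}, "module": {}}) returns False, since no datamodule is configured.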
This strategy launches a child subprocess for each additional GPU beyond the first, using the following base command:
python -m mushin.lightning._pl_main -cp
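To make the "one child per additional GPU" rule concrete, here is a rough sketch of how such commands could be assembled. The function build_child_commands and its config_dir argument (a stand-in for whatever follows -cp) are hypothetical, and the use of a LOCAL_RANK environment variable is an assumption borrowed from the convention Lightning's own subprocess launcher follows; this is not mushin's implementation.

```python
import os
import sys
from typing import Dict, List, Tuple


def build_child_commands(
    num_gpus: int, config_dir: str
) -> List[Tuple[List[str], Dict[str, str]]]:
    """Return one (argv, env) pair per child process, for local ranks 1..num_gpus-1.

    Rank 0 is the process Hydra already started, so it gets no command.
    `config_dir` is a hypothetical stand-in for the argument passed to -cp.
    """
    base = [sys.executable, "-m", "mushin.lightning._pl_main", "-cp", config_dir]
    children = []
    for local_rank in range(1, num_gpus):
        env = os.environ.copy()
        # Assumed convention: each child reads its rank from the environment.
        env["LOCAL_RANK"] = str(local_rank)
        children.append((list(base), env))
    return children
```

Each pair could then be handed to subprocess.Popen(argv, env=env); a real strategy would also need to wait on and clean up these children, which the sketch leaves out.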
Examples:
First, define a Hydra configuration using hydra-zen:
>>> import pytorch_lightning as pl
>>> from hydra_zen import builds, make_config
>>> from mushin import HydraDDP
>>> from mushin.testing.lightning import SimpleLightningModule
>>> # Trainer config: two GPUs, with HydraDDP as the distributed strategy
>>> TrainerConfig = builds(
...     pl.Trainer,
...     accelerator="auto",
...     gpus=2,
...     max_epochs=1,
...     fast_dev_run=True,
...     strategy=builds(HydraDDP),
...     populate_full_signature=True,
... )
>>> ModuleConfig = builds(SimpleLightningModule)
>>> # Top-level config with the keys HydraDDP expects
>>> Config = make_config(
...     trainer=TrainerConfig,
...     module=ModuleConfig,
... )
Next, define a task function to execute the Hydra job:
>>> from hydra_zen import instantiate
>>> def task_function(cfg):
... obj = instantiate(cfg)
... obj.trainer.fit(obj.module)
Launch the Hydra+Lightning DDP job:
>>> from hydra_zen import launch
>>> job = launch(Config, task_function)
HydraDDP also supports a LightningDataModule configuration:
>>> DataModuleConfig = ...  # A LightningDataModule config
>>> Config = make_config(
...     trainer=TrainerConfig,
...     module=ModuleConfig,
...     datamodule=DataModuleConfig,
... )
Next, define a task function to execute the Hydra job:
>>> from hydra_zen import instantiate
>>> def task_function(cfg):
... obj = instantiate(cfg)
... obj.trainer.fit(obj.module, datamodule=obj.datamodule)
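The two task functions differ only in whether a datamodule is passed. One way to cover both config variants with a single function is sketched below; fit_instantiated is a hypothetical name, and the sketch assumes the instantiated config exposes its entries as attributes, as obj.trainer and obj.module do above.

```python
from typing import Any


def fit_instantiated(obj: Any) -> None:
    """Hypothetical helper: run trainer.fit with or without a datamodule."""
    # The datamodule key is optional, so fall back to None when it is absent;
    # Trainer.fit accepts datamodule=None and then relies on the module's own dataloaders.
    datamodule = getattr(obj, "datamodule", None)
    obj.trainer.fit(obj.module, datamodule=datamodule)
```

Inside a task function it would be called as fit_instantiated(instantiate(cfg)).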
Launch the Hydra+Lightning DDP job:
>>> from hydra_zen import launch
>>> job = launch(Config, task_function)
MetricsCallback
Bases: Callback
Saves validation and test metrics stored in trainer.callback_metrics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| save_dir | str, optional | The directory in which to save the metrics. | '.' |
| filename | str, optional | The base filename used to store metrics. For … | 'metrics.pt' |
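As a rough illustration of what saving trainer.callback_metrics involves, the hypothetical MetricsRecorder below (not mushin's implementation) appends one snapshot of scalar metrics per call and writes the accumulated history to a file in save_dir. The stage-prefixed file name is an assumption, and the sketch uses pickle to stay dependency-free, whereas the .pt default suggests the real callback writes with torch.save.

```python
import os
import pickle
from collections import defaultdict
from typing import Dict, List, Mapping


class MetricsRecorder:
    """Hypothetical stand-in for MetricsCallback's bookkeeping (not mushin's code)."""

    def __init__(self, save_dir: str = ".", filename: str = "metrics.pt") -> None:
        self.save_dir = save_dir
        self.filename = filename
        self.history: Dict[str, List[float]] = defaultdict(list)

    def record(self, callback_metrics: Mapping[str, float]) -> None:
        # Append one snapshot, e.g. taken at the end of a validation epoch.
        # float() also converts scalar tensors, as found in trainer.callback_metrics.
        for name, value in callback_metrics.items():
            self.history[name].append(float(value))

    def save(self, stage: str) -> str:
        # Assumed naming scheme: prefix the stage so fit and test results stay separate.
        path = os.path.join(self.save_dir, f"{stage}_{self.filename}")
        with open(path, "wb") as f:
            pickle.dump(dict(self.history), f)
        return path
```

After two calls to record and one to save("fit"), the file fit_metrics.pt in save_dir holds a dict mapping each metric name to its list of recorded values.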
Notes
No metrics are saved during fitting (Trainer.fit) if no validation metrics are calculated.
This is a limitation of PyTorch Lightning. Future versions will save the training
step metrics when no validation metrics are calculated.
Examples: