Studies
Study combines a multi-seed training sweep with compare into a single call:
define your training functions once, specify seeds and evaluation data, and get
back a BenchmarkResult — no intermediate bookkeeping, no manual checkpoint
management.
Full motion: train + compare
The example below trains a CNN and MLP across seeds using Study, writing
checkpoints to disk and comparing the loaded models:
Defining train functions
import os
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader


def _make_cnn() -> nn.Module:
    return nn.Sequential(
        nn.Conv2d(1, 8, 3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(4),
        nn.Flatten(),
        nn.Linear(8 * 4 * 4, 10),
    )


def _make_mlp() -> nn.Module:
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 64),
        nn.ReLU(),
        nn.Linear(64, 10),
    )


def _train_and_save(model: nn.Module, loader: DataLoader, path: Path) -> str:
    """Train model for one epoch and save checkpoint; return its path."""
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    model.train()
    for x, y in loader:
        opt.zero_grad()
        nn.functional.cross_entropy(model(x), y).backward()
        opt.step()
    model.eval()
    # Saves the whole module (not just a state_dict), so it must be loaded
    # with weights_only=False.
    torch.save(model, path)
    return str(path)


def make_train_fn(name: str, model_factory, loader: DataLoader, ckpt_dir: Path):
    """Return a train_fn(seed) -> checkpoint_path for Study."""
    def train_fn(seed: int) -> str:
        torch.manual_seed(seed)
        model = model_factory()
        path = ckpt_dir / f"{name}_seed{seed}.pt"
        return _train_and_save(model, loader, path)
    return train_fn
Each train_fn(seed: int) -> str trains a model for the given seed, saves it,
and returns the checkpoint path. Study calls every train_fn for every seed
and stores the resulting paths.
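To make the contract concrete, here is a minimal sketch of the "every train_fn for every seed" loop. It is an illustration only, not Study's implementation (Study runs the calls through a Hydra sweep); `collect_checkpoints` and `fake_train_fn` are hypothetical names, and the stub train functions only build a path so the sketch runs without torch.

```python
from typing import Callable, Dict, List, Sequence

TrainFn = Callable[[int], str]


def collect_checkpoints(
    methods: Dict[str, TrainFn], seeds: Sequence[int]
) -> Dict[str, List[str]]:
    """Hypothetical helper: call each train_fn once per seed and gather the paths."""
    checkpoints: Dict[str, List[str]] = {}
    for name, train_fn in methods.items():
        paths: List[str] = []
        for seed in seeds:
            path = train_fn(seed)
            # A train_fn is expected to return the checkpoint path as a string.
            if not isinstance(path, str):
                raise ValueError(
                    f"{name!r} returned {path!r} for seed {seed}, expected a path"
                )
            paths.append(path)
        checkpoints[name] = paths
    return checkpoints


def fake_train_fn(name: str) -> TrainFn:
    """Stub that skips training and only returns where a checkpoint would go."""
    return lambda seed: f"checkpoints/{name}_seed{seed}.pt"


demo = collect_checkpoints(
    {"cnn": fake_train_fn("cnn"), "mlp": fake_train_fn("mlp")},
    seeds=[0, 1, 2],
)
print(demo["cnn"])
```

The result has one list of paths per method, in seed order, which is the same dict[str, list[str]] shape that a checkpoints mapping uses.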
Running the study
def run(
    train_loader: DataLoader,
    test_loader: DataLoader,
    *,
    seeds=(0, 1, 2),
    working_dir: str | os.PathLike[str],
) -> BenchmarkResult:
    """Train CNN and MLP across seeds on ``train_loader``, then compare them on
    the held-out ``test_loader``."""
    # Resolve to an absolute path: each train_fn runs inside Hydra's per-job
    # working directory, so a relative checkpoint dir would not point at the
    # directory we create here.
    work = Path(working_dir).resolve()
    ckpt_dir = work / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    methods = {
        "cnn": make_train_fn("cnn", _make_cnn, train_loader, ckpt_dir),
        "mlp": make_train_fn("mlp", _make_mlp, train_loader, ckpt_dir),
    }

    study = Study(
        methods=methods,
        load_fn=lambda p: torch.load(p, weights_only=False),
        seeds=list(seeds),
        data=test_loader,
        num_classes=10,
        test="welch",
        working_dir=str(work),
    )
    return study.run()
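The comment about resolving working_dir can be checked with the standard library alone. The sketch below simulates a working-directory change with os.chdir; the `job_0` directory is a hypothetical stand-in for a per-job directory, and nothing here calls Hydra itself.

```python
import os
import tempfile
from pathlib import Path

start_dir = Path.cwd()
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp).resolve()
    job_dir = root / "job_0"  # stand-in for a per-job working directory
    job_dir.mkdir()

    os.chdir(root)
    try:
        relative = Path("checkpoints")
        absolute = relative.resolve()  # pinned to root / "checkpoints"

        os.chdir(job_dir)  # simulate the working-directory change
        relative_now = relative.resolve()  # now resolved against job_dir
        absolute_now = absolute  # an absolute path is unaffected by chdir
    finally:
        os.chdir(start_dir)  # restore cwd before the temp dir is removed

print(relative_now == absolute_now)  # False: the relative path moved
```

Resolving once, before any directory change, is what keeps every train_fn writing into the same checkpoint directory.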
Study parameters:
| Parameter | Description |
|---|---|
| methods | Dict mapping method name to a train_fn(seed: int) -> str (returns checkpoint path). |
| load_fn | Callable that loads a checkpoint path into a torch.nn.Module. |
| seeds | List of integer seeds to train each method on. |
| data | Re-iterable data loader for evaluation. |
| num_classes | Number of classes (required; Study evaluates with the default battery for its task). |
| task | "classification" (default) or "segmentation". |
| test | Statistical test: "welch", "wilcoxon", "mannwhitney", etc. |
| alpha | Significance threshold (default 0.05). |
| ignore_index | For segmentation: label to exclude (e.g. void class). |
| working_dir | Directory for Hydra sweep outputs (default: current directory). |
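As a rough illustration of what test="welch" refers to, the sketch below computes Welch's t statistic and the Welch-Satterthwaite degrees of freedom for two sets of per-seed scores. It is a from-scratch sketch using only the standard library, not the code Study runs; `welch_t` is a hypothetical name, the accuracies are made-up numbers, and the final step of turning the statistic into a p-value is left out.

```python
import math
from statistics import mean, variance
from typing import Sequence, Tuple


def welch_t(a: Sequence[float], b: Sequence[float]) -> Tuple[float, float]:
    """Return Welch's t statistic and the Welch-Satterthwaite degrees of freedom."""
    # Squared standard error of each group's mean (sample variance / n).
    se_a = variance(a) / len(a)
    se_b = variance(b) / len(b)
    t = (mean(a) - mean(b)) / math.sqrt(se_a + se_b)
    df = (se_a + se_b) ** 2 / (
        se_a ** 2 / (len(a) - 1) + se_b ** 2 / (len(b) - 1)
    )
    return t, df


# Illustrative per-seed accuracies (three seeds per method), not real results.
cnn_acc = [0.96, 0.97, 0.95]
mlp_acc = [0.94, 0.95, 0.93]

t_stat, dof = welch_t(cnn_acc, mlp_acc)
print(round(t_stat, 2), round(dof, 1))  # 2.45 4.0
```

Unlike Student's t-test, Welch's variant does not assume the two methods have equal variance across seeds, which is why the degrees of freedom are computed from the data rather than fixed at n1 + n2 - 2.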
After study.run(), the checkpoint paths are stored at study.checkpoints
(dict[str, list[str]]) and study.working_dir records the resolved directory.
Annotated output
result.summary()
# method | metric   | mean  | ci_low | ci_high | significant_vs_ref
# cnn    | accuracy | 0.963 | 0.951  | 0.975   |
# mlp    | accuracy | 0.941 | 0.928  | 0.954   | *
result.data
# xarray.Dataset, dims (method, seed): one variable per metric
result.comparisons
# tidy DataFrame: pairwise p-values, effect sizes, Holm-corrected significance
"*" in significant_vs_ref means the method differs significantly from the
reference (the first method listed) after Holm correction.
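For readers unfamiliar with Holm correction, the following is a minimal sketch of the step-down procedure applied to raw p-values against a reference. It is written from the textbook definition and is not taken from the library; `holm_reject`, the method names "resnet" and "vit", and the p-values are all hypothetical.

```python
from typing import Dict


def holm_reject(p_values: Dict[str, float], alpha: float = 0.05) -> Dict[str, bool]:
    """Holm step-down: compare the k-th smallest p-value (k from 0) to alpha / (m - k)."""
    m = len(p_values)
    ordered = sorted(p_values.items(), key=lambda item: item[1])
    reject = {name: False for name in p_values}
    for k, (name, p) in enumerate(ordered):
        if p > alpha / (m - k):
            break  # once one hypothesis is retained, all larger p-values are too
        reject[name] = True
    return reject


# Hypothetical raw p-values for three methods compared against a reference.
raw = {"mlp": 0.010, "resnet": 0.040, "vit": 0.030}
flags = holm_reject(raw, alpha=0.05)
print(flags)  # {'mlp': True, 'resnet': False, 'vit': False}
```

All three raw p-values are below 0.05, but only the smallest survives the correction: 0.010 passes the first threshold of 0.05 / 3, while 0.030 fails the second threshold of 0.05 / 2 and stops the procedure.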
Eval-only: from_checkpoints
Already have checkpoints? Skip training entirely:
Study.from_checkpoints(
    checkpoints={
        "cnn": ["cnn_seed0.pt", "cnn_seed1.pt", "cnn_seed2.pt"],
        "mlp": ["mlp_seed0.pt", "mlp_seed1.pt", "mlp_seed2.pt"],
    },
    load_fn=lambda p: torch.load(p, weights_only=False),
    data=val_loader,
    num_classes=10,
    test="welch",
).run()
Study.from_checkpoints takes the same evaluation parameters as Study.__init__
but accepts a pre-built checkpoints dict instead of methods and seeds.
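If the checkpoints already sit in one directory, the dict can be built rather than typed by hand. The sketch below assumes files follow the `{name}_seed{seed}.pt` naming used by make_train_fn earlier on this page; `discover_checkpoints` is a hypothetical helper, not part of the library, and the demo creates empty placeholder files in a temporary directory.

```python
import re
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

# Assumed naming convention: "<method>_seed<int>.pt"
PATTERN = re.compile(r"^(?P<name>.+)_seed(?P<seed>\d+)\.pt$")


def discover_checkpoints(ckpt_dir: Path) -> Dict[str, List[str]]:
    """Group checkpoint files by method name, ordered by numeric seed."""
    found = defaultdict(list)
    for path in ckpt_dir.glob("*.pt"):
        match = PATTERN.match(path.name)
        if match:
            found[match["name"]].append((int(match["seed"]), str(path)))
    # Sort by the integer seed so that seed10 comes after seed2.
    return {
        name: [p for _, p in sorted(items)]
        for name, items in sorted(found.items())
    }


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    for name in ("cnn", "mlp"):
        for seed in (0, 2, 10):
            (ckpt_dir / f"{name}_seed{seed}.pt").touch()
    checkpoints = discover_checkpoints(ckpt_dir)
    seed_order = [Path(p).name for p in checkpoints["cnn"]]

print(seed_order)  # ['cnn_seed0.pt', 'cnn_seed2.pt', 'cnn_seed10.pt']
```

The returned mapping has the same shape as the checkpoints argument shown above, so it could be passed straight through.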
Pitfalls
- train_fn must return a path: If it returns None, Study raises a ValueError. Always return the saved checkpoint path.
- Re-iterable data: data must be a DataLoader, not a one-shot iterator, because it is evaluated once per model.
- working_dir and Hydra: Study runs a Hydra sweep internally; if Hydra's working-directory change behavior conflicts with your setup, pass an explicit working_dir.
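The re-iterable requirement is easy to see with plain Python objects. In this sketch a list stands in for a DataLoader and a bare iterator stands in for a one-shot source; `evaluate_each` is a hypothetical stand-in for an evaluation loop and only counts the batches each model would see.

```python
from typing import Iterable, List


def evaluate_each(models: List[str], data: Iterable[int]) -> List[int]:
    """Count how many batches each (stand-in) model sees when iterating data."""
    return [sum(1 for _ in data) for _ in models]


batches = [0, 1, 2, 3]    # re-iterable: every pass starts from the beginning
one_shot = iter(batches)  # one-shot: exhausted after the first pass

seen_reiterable = evaluate_each(["cnn", "mlp"], batches)
seen_one_shot = evaluate_each(["cnn", "mlp"], one_shot)
print(seen_reiterable)  # [4, 4]
print(seen_one_shot)    # [4, 0]

# A cheap guard: an iterator returns itself from iter(), a re-iterable does not.
fresh = iter(batches)
is_one_shot = iter(fresh) is fresh
print(is_one_shot)  # True
```

With a one-shot iterator the second model is evaluated on zero batches, which would silently skew any comparison; the iter() identity check is one way to catch this before evaluation starts.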
See also
- Comparing methods guide — details on statistical tests and results
- API Reference — study