
Studies

Study combines a multi-seed training sweep with compare into a single call: define your training functions once, specify seeds and evaluation data, and get back a BenchmarkResult — no intermediate bookkeeping, no manual checkpoint management.

Full workflow: train + compare

The example below uses Study to train a CNN and an MLP across several seeds, write their checkpoints to disk, and compare the reloaded models:

Defining train functions

from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader


def _make_cnn() -> nn.Module:
    return nn.Sequential(
        nn.Conv2d(1, 8, 3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(4),
        nn.Flatten(),
        nn.Linear(8 * 4 * 4, 10),
    )


def _make_mlp() -> nn.Module:
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 64),
        nn.ReLU(),
        nn.Linear(64, 10),
    )


def _train_and_save(model: nn.Module, loader: DataLoader, path: Path) -> str:
    """Train model for one epoch and save checkpoint; return its path."""
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    model.train()
    for x, y in loader:
        opt.zero_grad()
        nn.functional.cross_entropy(model(x), y).backward()
        opt.step()
    model.eval()
    torch.save(model, path)
    return str(path)


def make_train_fn(name: str, model_factory, loader: DataLoader, ckpt_dir: Path):
    """Return a train_fn(seed) -> checkpoint_path for Study."""

    def train_fn(seed: int) -> str:
        torch.manual_seed(seed)
        model = model_factory()
        path = ckpt_dir / f"{name}_seed{seed}.pt"
        return _train_and_save(model, loader, path)

    return train_fn

Each train_fn(seed: int) -> str trains a model for the given seed, saves it, and returns the checkpoint path. Study calls every train_fn for every seed and stores the resulting paths.
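Conceptually, the per-method, per-seed bookkeeping that Study takes over is a nested loop that produces the dict[str, list[str]] later exposed as study.checkpoints. The sketch below is a hypothetical illustration, not Study's code: collect_checkpoints is an invented helper, the real Study dispatches the calls through a Hydra sweep rather than a plain loop, and the None check simply mirrors the ValueError described under Pitfalls. The fake train functions only build path strings, so the example runs without torch.

```python
from typing import Callable, Dict, List, Sequence

TrainFn = Callable[[int], str]


def collect_checkpoints(
    methods: Dict[str, TrainFn], seeds: Sequence[int]
) -> Dict[str, List[str]]:
    """Hypothetical stand-in for Study's bookkeeping: call every train_fn
    once per seed and gather the returned checkpoint paths by method."""
    checkpoints: Dict[str, List[str]] = {}
    for name, train_fn in methods.items():
        paths: List[str] = []
        for seed in seeds:
            path = train_fn(seed)
            if path is None:
                # Mirrors the documented pitfall: a train_fn must return a path.
                raise ValueError(f"train_fn for {name!r} returned None for seed {seed}")
            paths.append(str(path))
        checkpoints[name] = paths
    return checkpoints


# Fake train functions that only build paths, so the sketch runs without torch.
fake_methods = {
    "cnn": lambda seed: f"checkpoints/cnn_seed{seed}.pt",
    "mlp": lambda seed: f"checkpoints/mlp_seed{seed}.pt",
}
print(collect_checkpoints(fake_methods, [0, 1, 2]))
```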

Running the study

import os

# Also requires Study and BenchmarkResult from this library.


def run(
    train_loader: DataLoader,
    test_loader: DataLoader,
    *,
    seeds=(0, 1, 2),
    working_dir: str | os.PathLike[str],
) -> BenchmarkResult:
    """Train CNN and MLP across seeds on ``train_loader``, then compare them on
    the held-out ``test_loader``."""
    # Resolve to an absolute path: each train_fn runs inside Hydra's per-job
    # working directory, so a relative checkpoint dir would not point at the
    # directory we create here.
    work = Path(working_dir).resolve()
    ckpt_dir = work / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    methods = {
        "cnn": make_train_fn("cnn", _make_cnn, train_loader, ckpt_dir),
        "mlp": make_train_fn("mlp", _make_mlp, train_loader, ckpt_dir),
    }

    study = Study(
        methods=methods,
        load_fn=lambda p: torch.load(p, weights_only=False),
        seeds=list(seeds),
        data=test_loader,
        num_classes=10,
        test="welch",
        working_dir=str(work),
    )
    return study.run()
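The comment in run() about resolving working_dir can be reproduced with the standard library alone: a relative path is resolved against whatever the current directory is at the moment it is used, so after a job changes directory the same relative string names a different location. In this sketch the second os.chdir call is a stand-in for Hydra's per-job directory change (an assumption about its effect, not Hydra code), and all directories live in a throwaway temp folder.

```python
import os
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp).resolve()  # resolve symlinks in the temp path (e.g. on macOS)
    launch_dir = root / "launch"  # where run() would be called from
    job_dir = root / "outputs" / "job0"  # stand-in for a per-job directory
    launch_dir.mkdir()
    job_dir.mkdir(parents=True)

    original_cwd = os.getcwd()
    try:
        os.chdir(launch_dir)
        relative = Path("checkpoints")
        absolute = relative.resolve()  # fixed while still in launch_dir

        os.chdir(job_dir)  # simulate the per-job directory change
        # The relative path now points inside job_dir ...
        moved = relative.resolve() == job_dir / "checkpoints"
        # ... while the already-resolved path still points inside launch_dir.
        stable = absolute == launch_dir / "checkpoints"
    finally:
        os.chdir(original_cwd)  # leave the temp dir before it is deleted

print(moved, stable)
```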

Study parameters:

| Parameter | Description |
|---|---|
| methods | Dict mapping method name to a train_fn(seed: int) -> str that returns a checkpoint path. |
| load_fn | Callable that loads a checkpoint path into a torch.nn.Module. |
| seeds | List of integer seeds to train each method on. |
| data | Re-iterable data loader used for evaluation. |
| num_classes | Number of classes. Required, because Study evaluates with the default metric battery for its task. |
| task | "classification" (default) or "segmentation". |
| test | Statistical test: "welch", "wilcoxon", "mannwhitney", etc. |
| alpha | Significance threshold (default 0.05). |
| ignore_index | For segmentation: label to exclude (e.g. a void class). |
| working_dir | Directory for Hydra sweep outputs (default: current directory). |

After study.run(), the checkpoint paths are stored at study.checkpoints (dict[str, list[str]]) and study.working_dir records the resolved directory.
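The examples on this page use test="welch". For reference, Welch's unequal-variance t-test compares two groups of per-seed scores without assuming the groups share a variance. The sketch below computes only the textbook t statistic and Welch-Satterthwaite degrees of freedom with the standard library; it is not Study's implementation, the accuracies are made-up illustrative numbers, and turning t into a p-value would additionally need a Student-t CDF (for example scipy.stats.t.sf), which is left out.

```python
from statistics import mean, variance
from typing import Sequence, Tuple


def welch_t(a: Sequence[float], b: Sequence[float]) -> Tuple[float, float]:
    """Return Welch's t statistic and the Welch-Satterthwaite degrees of freedom."""
    na, nb = len(a), len(b)
    # Squared standard errors: sample variance (n - 1 denominator) over n.
    va, vb = variance(a) / na, variance(b) / nb
    t = (mean(a) - mean(b)) / (va + vb) ** 0.5
    df = (va + vb) ** 2 / (va**2 / (na - 1) + vb**2 / (nb - 1))
    return t, df


# Illustrative per-seed accuracies for two methods (not real results).
cnn_scores = [0.96, 0.97, 0.95]
mlp_scores = [0.94, 0.95, 0.93]
t, df = welch_t(cnn_scores, mlp_scores)
print(f"t = {t:.3f}, df = {df:.2f}")
```

With only a handful of seeds per method, df stays small, which is one reason the per-seed count matters for how easily a difference reaches significance.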

Annotated output

# result: the BenchmarkResult returned by study.run()
result.summary()
# method | metric    | mean   | ci_low | ci_high | significant_vs_ref
# cnn    | accuracy  | 0.963  | 0.951  | 0.975   |
# mlp    | accuracy  | 0.941  | 0.928  | 0.954   | *

result.data
# xarray.Dataset, dims (method, seed): one variable per metric
result.comparisons
# tidy DataFrame: pairwise p-values, effect sizes, Holm-corrected significance
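The Holm-corrected significance reported here (and behind the "*" in significant_vs_ref) refers to Holm's step-down procedure: sort the m p-values in ascending order, compare the k-th smallest (counting from 0) against alpha / (m - k), and stop rejecting at the first p-value that fails. The sketch below is the textbook procedure, not Study's internal code; holm_reject is a hypothetical helper, and the method names and p-values are made up.

```python
from typing import Dict


def holm_reject(pvalues: Dict[str, float], alpha: float = 0.05) -> Dict[str, bool]:
    """Holm-Bonferroni step-down: return which hypotheses are rejected."""
    m = len(pvalues)
    ordered = sorted(pvalues.items(), key=lambda kv: kv[1])  # smallest p first
    reject = {name: False for name in pvalues}
    for k, (name, p) in enumerate(ordered):
        if p <= alpha / (m - k):
            reject[name] = True
        else:
            break  # once one test fails, every larger p-value is retained
    return reject


# Hypothetical p-values for three methods compared against a reference.
print(holm_reject({"mlp": 0.010, "method_b": 0.040, "method_c": 0.030}))
# {'mlp': True, 'method_b': False, 'method_c': False}
```

Note that method_b is retained even though 0.040 < 0.05: the procedure stopped at method_c (0.030 > 0.05 / 2), which is what makes Holm stricter than testing each p-value alone.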

"*" in significant_vs_ref means the method differs significantly from the reference (the first method listed) after Holm correction.

Eval-only: from_checkpoints

Already have checkpoints? Skip training entirely:

Study.from_checkpoints(
    checkpoints={
        "cnn": ["cnn_seed0.pt", "cnn_seed1.pt", "cnn_seed2.pt"],
        "mlp": ["mlp_seed0.pt", "mlp_seed1.pt", "mlp_seed2.pt"],
    },
    load_fn=lambda p: torch.load(p, weights_only=False),
    data=val_loader,
    num_classes=10,
    test="welch",
).run()

Study.from_checkpoints takes the same evaluation parameters as Study.__init__ but accepts a pre-built checkpoints dict instead of methods and seeds.
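If the checkpoints were written with the {name}_seed{seed}.pt naming used by make_train_fn above, the checkpoints dict can be rebuilt from a directory instead of typed by hand. A sketch under that naming assumption; checkpoints_from_dir is a hypothetical helper, and it sorts each method's files by numeric seed so that seed 10 does not land between seeds 1 and 2.

```python
import re
import tempfile
from pathlib import Path
from typing import Dict, List, Tuple

# Matches names like "cnn_seed0.pt", as produced by make_train_fn above.
_CKPT_RE = re.compile(r"^(?P<name>.+)_seed(?P<seed>\d+)\.pt$")


def checkpoints_from_dir(ckpt_dir: Path) -> Dict[str, List[str]]:
    """Hypothetical helper: group <name>_seed<N>.pt files by method name."""
    found: Dict[str, List[Tuple[int, str]]] = {}
    for path in ckpt_dir.iterdir():
        match = _CKPT_RE.match(path.name)
        if match is None:
            continue  # ignore files outside the naming scheme
        found.setdefault(match["name"], []).append((int(match["seed"]), str(path)))
    # Order methods by name and each method's files by numeric seed.
    return {name: [p for _, p in sorted(items)] for name, items in sorted(found.items())}


# Demo on a throwaway directory containing empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("cnn", "mlp"):
        for seed in (0, 1, 2, 10):
            (Path(tmp) / f"{name}_seed{seed}.pt").touch()
    checkpoints = checkpoints_from_dir(Path(tmp))
    print({name: [Path(p).name for p in paths] for name, paths in checkpoints.items()})
```

Sorting method names alphabetically also fixes which method comes first. Assuming from_checkpoints treats the first key of the dict as the reference, the same way Study treats the first method listed, reorder the dict if you need a different reference.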

Pitfalls

  • train_fn must return a path: If it returns None, Study raises a ValueError. Always return the saved checkpoint path.
  • Re-iterable data: data must be re-iterable, such as a DataLoader, not a one-shot iterator like a generator, because it is iterated once per model.
  • working_dir and Hydra: Study runs a Hydra sweep internally; if Hydra's working-directory change behavior conflicts with your setup, pass an explicit working_dir.
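The re-iterable requirement exists because a generator can only be consumed once: the second model evaluated would see no batches at all. A minimal sketch with plain Python stand-ins; count_batches is a hypothetical placeholder for one model's evaluation pass, and the list plays the role of a DataLoader, which starts a fresh pass each time it is iterated.

```python
from typing import Iterable


def count_batches(data: Iterable) -> int:
    """Stand-in for evaluating one model: consume one full pass over data."""
    return sum(1 for _ in data)


batches = [("x0", "y0"), ("x1", "y1"), ("x2", "y2")]

one_shot = (b for b in batches)  # generator: a single pass only
print([count_batches(one_shot) for _model in range(2)])  # [3, 0]

reiterable = batches  # a list, like a DataLoader, restarts on every iter()
print([count_batches(reiterable) for _model in range(2)])  # [3, 3]
```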

See also