Tutorial

This tutorial walks you through the full mushin workflow end to end: define a sweep, collect a labeled dataset, compare methods with statistical significance, and interpret the result.

Step 1: Define a sweep

mushin workflows are subclasses of MultiRunMetricsWorkflow. You implement a static task(...) method that returns a dict of metrics, then call .run(...) with the swept parameters:

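import torch as tr
from pathlib import Path

# MultiRunMetricsWorkflow and multirun are mushin's; _make_data, LEARNING_RATES,
# and SEEDS are tutorial helpers not shown here (three learning rates, three seeds).


class LRSweep(MultiRunMetricsWorkflow):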
    @staticmethod
    def task(lr: float, seed: int) -> dict:
        tr.manual_seed(seed)
        x, y = _make_data(seed)
        model = tr.nn.Linear(2, 1)
        opt = tr.optim.SGD(model.parameters(), lr=lr)
        for _ in range(100):
            opt.zero_grad()
            logits = model(x).squeeze(1)
            loss = tr.nn.functional.binary_cross_entropy_with_logits(logits, y)
            loss.backward()
            opt.step()
        with tr.no_grad():
            preds = (model(x).squeeze(1) > 0).float()
            acc = (preds == y).float().mean().item()
        # returning the dict is what populates the dataset; saving is optional
        result = dict(accuracy=acc)
        tr.save(result, "metrics.pt")
        return result


def build_dataset(working_dir: Path | None = None):
    """Run the learning-rate x seed sweep and return an ``xarray.Dataset``."""
    wf = LRSweep()
    wf.run(
        lr=multirun(LEARNING_RATES),
        seed=multirun(SEEDS),
        working_dir=str(working_dir) if working_dir is not None else None,
    )
    return wf.to_xarray()

The multirun(...) wrapper marks a parameter as swept, and Hydra launches one job for every combination of swept values (their Cartesian product). Here that is 3 × 3 = 9 jobs: three learning rates × three seeds. Each job runs in its own output directory, and the dict returned by task is collected automatically.
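Hydra performs this expansion itself; the stdlib-only sketch below just makes the job count concrete. It is not mushin or Hydra code: expand_sweep is a hypothetical helper, and the learning-rate values are placeholders, since the tutorial does not list LEARNING_RATES.

```python
from itertools import product

# Placeholder sweep values (hypothetical): the tutorial does not list LEARNING_RATES.
LEARNING_RATES = (0.01, 0.1, 1.0)
SEEDS = (0, 1, 2)


def expand_sweep(**swept):
    """Return one parameter dict per combination of swept values (Cartesian product)."""
    names = list(swept)
    return [dict(zip(names, combo)) for combo in product(*swept.values())]


jobs = expand_sweep(lr=LEARNING_RATES, seed=SEEDS)
print(len(jobs))  # 9
print(jobs[0])    # {'lr': 0.01, 'seed': 0}
```

Adding a third swept parameter with four values would multiply the job count by four (36 jobs), which is worth checking before launching a large sweep.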

Step 2: Collect the dataset

After .run(...) completes, call .to_xarray() to get a labeled xarray.Dataset:

ds = wf.to_xarray()
# <xarray.Dataset> Dimensions: (lr: 3, seed: 3)
#   Data variables: accuracy (lr, seed)

ds["accuracy"].mean("seed")   # average accuracy per learning rate
ds.sel(lr=0.1)                # slice to a single lr

The dimensions come from the swept parameters; the data variables come from the dict your task returned.
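to_xarray() handles this bookkeeping for you. As a rough mental model of that mapping, the stdlib-only sketch below keys each job's returned dict by its swept coordinates and then averages over seed, mirroring ds["accuracy"].mean("seed"). toy_task, collect, and mean_over are hypothetical helpers, not mushin API, and the accuracies are made-up deterministic numbers, not results.

```python
from collections import defaultdict
from itertools import product
from statistics import mean


def toy_task(lr: float, seed: int) -> dict:
    # Hypothetical stand-in for LRSweep.task: no training, just a deterministic
    # made-up number so the bookkeeping is easy to check by hand.
    return {"accuracy": 0.9 - abs(lr - 0.1) + 0.01 * seed}


def collect(task, **swept):
    """Run task once per combination; key each metrics dict by its coordinates."""
    names = tuple(swept)  # dimension names, in sweep order
    results = {}
    for combo in product(*swept.values()):
        results[combo] = task(**dict(zip(names, combo)))
    return names, results


def mean_over(names, results, metric, dim):
    """Average one metric over one dimension, like ds[metric].mean(dim)."""
    axis = names.index(dim)
    groups = defaultdict(list)
    for coords, metrics in results.items():
        remaining = coords[:axis] + coords[axis + 1:]  # drop the reduced dimension
        groups[remaining].append(metrics[metric])
    return {key: mean(vals) for key, vals in groups.items()}


names, results = collect(toy_task, lr=(0.01, 0.1, 1.0), seed=(0, 1, 2))
per_lr = mean_over(names, results, "accuracy", "seed")
print({k: round(v, 3) for k, v in per_lr.items()})
# {(0.01,): 0.82, (0.1,): 0.91, (1.0,): 0.01}
```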

Step 3: Compare methods with statistics

Once you have trained models, pass them to compare as a dict mapping each method name to a list of models (one model per seed):

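import torch
from torch import nn
from torch.utils.data import DataLoader

# compare and BenchmarkResult are mushin's; small_cnn, mlp, and _train are
# tutorial helpers not shown here (_train fits a model and returns it).


def run(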
    train_loader: DataLoader, test_loader: DataLoader, *, seeds=(0, 1, 2)
) -> BenchmarkResult:
    """Train one CNN and one MLP per seed, then compare them with statistics."""
    methods: dict[str, list[nn.Module]] = {"cnn": [], "mlp": []}
    for seed in seeds:
        torch.manual_seed(seed)
        methods["cnn"].append(_train(small_cnn(), train_loader))
        methods["mlp"].append(_train(mlp(), train_loader))

    return compare(
        methods,
        data=test_loader,
        task="classification",
        num_classes=10,
        test="welch",
    )

compare evaluates every model on data, assembles a (method × seed) xarray Dataset of metrics, and runs pairwise significance tests (Welch's t-test here, selected with test="welch") with Holm correction for multiple comparisons.
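Assuming test="welch" selects Welch's unequal-variance t-test, as the name suggests, each comparison starts from the statistic and Welch-Satterthwaite degrees of freedom computed below from per-seed metric values. This is a standalone sketch, not mushin's implementation; turning t into a p-value needs the Student-t CDF, which scipy.stats.ttest_ind(a, b, equal_var=False) provides.

```python
from math import sqrt
from statistics import mean, variance


def welch_t(a: list[float], b: list[float]) -> tuple[float, float]:
    """Welch's t statistic and Welch-Satterthwaite degrees of freedom."""
    # Squared standard error of each mean, using the sample variance (n - 1).
    se2_a = variance(a) / len(a)
    se2_b = variance(b) / len(b)
    t = (mean(a) - mean(b)) / sqrt(se2_a + se2_b)
    df = (se2_a + se2_b) ** 2 / (
        se2_a**2 / (len(a) - 1) + se2_b**2 / (len(b) - 1)
    )
    return t, df


# Toy inputs (not results): two groups of three values with equal spread.
t, df = welch_t([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
print(round(t, 3), round(df, 3))  # -3.674 4.0
```

Welch's test does not assume the two methods have equal seed-to-seed variance, which is usually the safer default when comparing different architectures. With only a few seeds per method, df is small and the test has little power.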

Step 4: Read the statistics

result = run(train_loader, test_loader)
result.summary()
# method | metric   | mean  | ci_low | ci_high | significant_vs_ref
# cnn    | accuracy | 0.963 | 0.951  | 0.975   |
# mlp    | accuracy | 0.941 | 0.928  | 0.954   | *

result.data           # xarray.Dataset, dims (method, seed)
result.comparisons    # tidy DataFrame with p-values and effect sizes
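The tutorial does not say how mushin computes ci_low and ci_high (a Student-t interval and a bootstrap are both common choices). Assuming a two-sided 95% Student-t interval over the per-seed values, the computation looks like this; t_interval and its critical-value table are illustrative, not mushin API.

```python
from math import sqrt
from statistics import mean, stdev

# Two-sided 95% Student-t critical values t_{0.975, df} for small df.
T_975 = {1: 12.706, 2: 4.303, 3: 3.182, 4: 2.776, 5: 2.571,
         6: 2.447, 7: 2.365, 8: 2.306, 9: 2.262, 10: 2.228}


def t_interval(values: list[float]) -> tuple[float, float, float]:
    """Mean and 95% Student-t confidence interval across seeds (2 to 11 values)."""
    n = len(values)
    if not 2 <= n <= 11:
        raise ValueError("critical-value table covers 2 to 11 values")
    m = mean(values)
    half = T_975[n - 1] * stdev(values) / sqrt(n)  # t * standard error of the mean
    return m, m - half, m + half


# Toy per-seed values (not results); half-width = 4.303 * stdev / sqrt(3).
m, lo, hi = t_interval([1.0, 2.0, 3.0])
```

With three seeds the critical value is 4.303 rather than the familiar 1.96, so intervals are wide relative to the seed-to-seed spread; adding seeds narrows them faster than the 1/√n factor alone suggests.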

A "*" in the significant_vs_ref column means the method differs significantly from the reference (the first method listed, here cnn) after Holm correction at alpha = 0.05; the reference row itself is left blank.
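Holm's procedure controls the family-wise error rate across several comparisons by testing p-values from smallest to largest against progressively looser thresholds (alpha/m, alpha/(m-1), ...). The sketch below is the textbook step-down procedure expressed as adjusted p-values, not mushin's code; the comparison names and p-values are made up. Assuming the correction is applied across method comparisons for a single metric, the two-method setup in this tutorial has one comparison and Holm leaves its p-value unchanged; the correction starts to matter with three or more methods.

```python
def holm(pvalues: dict[str, float], alpha: float = 0.05) -> dict[str, tuple[float, bool]]:
    """Holm step-down correction: (adjusted p-value, significant?) per comparison."""
    m = len(pvalues)
    ordered = sorted(pvalues.items(), key=lambda kv: kv[1])  # smallest p first
    out = {}
    running_max = 0.0
    for rank, (name, p) in enumerate(ordered):
        # The k-th smallest p-value (rank = k - 1) is multiplied by m - k + 1.
        adjusted = min(1.0, (m - rank) * p)
        # Adjusted p-values must not decrease down the ordering (step-down rule).
        running_max = max(running_max, adjusted)
        out[name] = (running_max, running_max <= alpha)
    return out


# Made-up p-values for three hypothetical comparisons against a reference.
print(holm({"a": 0.01, "b": 0.04, "c": 0.03}))
```

Here only "a" is significant: "c" gets an adjusted p-value of 0.06, and "b" inherits that 0.06 as well, so it is not significant even though its raw p-value of 0.04 is below 0.05.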

Next steps