Tutorial
This tutorial walks through the mushin workflow end to end: define a sweep, collect a labeled dataset, test whether methods differ significantly, and interpret the results.
Step 1: Define a sweep
mushin workflows are subclasses of MultiRunMetricsWorkflow. You implement a
task(...) method that returns a dict of metrics, then call .run(...) with
swept parameters:
from pathlib import Path

import torch as tr

# Imports for MultiRunMetricsWorkflow and multirun are omitted here, as are
# the definitions of _make_data, LEARNING_RATES, and SEEDS.


class LRSweep(MultiRunMetricsWorkflow):
    @staticmethod
    def task(lr: float, seed: int) -> dict:
        tr.manual_seed(seed)
        x, y = _make_data(seed)
        model = tr.nn.Linear(2, 1)
        opt = tr.optim.SGD(model.parameters(), lr=lr)
        # train a logistic-regression model for 100 full-batch steps
        for _ in range(100):
            opt.zero_grad()
            logits = model(x).squeeze(1)
            loss = tr.nn.functional.binary_cross_entropy_with_logits(logits, y)
            loss.backward()
            opt.step()
        # measure accuracy on the same data, without tracking gradients
        with tr.no_grad():
            preds = (model(x).squeeze(1) > 0).float()
            acc = (preds == y).float().mean().item()
        # returning the dict is what populates the dataset; saving is optional
        result = dict(accuracy=acc)
        tr.save(result, "metrics.pt")
        return result


def build_dataset(working_dir: Path | None = None):
    """Run the learning-rate x seed sweep and return an ``xarray.Dataset``."""
    wf = LRSweep()
    wf.run(
        lr=multirun(LEARNING_RATES),
        seed=multirun(SEEDS),
        working_dir=str(working_dir) if working_dir is not None else None,
    )
    return wf.to_xarray()
The multirun(...) wrapper tells Hydra to create one job per value. Here the
sweep creates 3 × 3 = 9 jobs (three learning rates × three seeds). Each job
runs in its own output directory; the returned dict is collected automatically.
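To make the job count concrete, the sketch below enumerates the parameter combinations that a two-parameter sweep would produce. It uses only itertools and does not call Hydra or mushin. The helper name enumerate_jobs and the values in LEARNING_RATES and SEEDS are hypothetical placeholders, since the tutorial does not list the actual values.

```python
from itertools import product

# Hypothetical sweep values; the tutorial's actual lists are not shown.
LEARNING_RATES = [0.01, 0.1, 1.0]
SEEDS = [0, 1, 2]


def enumerate_jobs(**swept):
    """Return one parameter dict per job: the Cartesian product of all swept values."""
    names = list(swept)
    return [dict(zip(names, combo)) for combo in product(*swept.values())]


jobs = enumerate_jobs(lr=LEARNING_RATES, seed=SEEDS)
print(len(jobs))  # 9
print(jobs[0])    # {'lr': 0.01, 'seed': 0}
```

Adding a third swept parameter multiplies the job count again, so sweep sizes grow quickly with each extra axis.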
Step 2: Collect the dataset
After .run(...) completes, call .to_xarray() to get a labeled xarray.Dataset:
ds = wf.to_xarray()
# <xarray.Dataset> Dimensions: (lr: 3, seed: 3)
# Data variables: accuracy (lr, seed)
ds["accuracy"].mean("seed") # average accuracy per learning rate
ds.sel(lr=0.1) # slice to a single lr
The dimensions come from the swept parameters; the data variables come from the
dict your task returned.
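For intuition about what the labeled dimensions provide, here is a plain-Python analogue of the two operations above, using a dict keyed by (lr, seed). It does not use xarray, the helper names are hypothetical, and the accuracy values are made up for illustration.

```python
from statistics import mean

# Made-up accuracies keyed by (lr, seed); a real sweep fills these from task().
accuracy = {
    (0.01, 0): 0.80, (0.01, 1): 0.82, (0.01, 2): 0.84,
    (0.1, 0): 0.90, (0.1, 1): 0.92, (0.1, 2): 0.94,
    (1.0, 0): 0.70, (1.0, 1): 0.60, (1.0, 2): 0.80,
}


def mean_over_seed(table):
    """Average over the seed axis, analogous to ds["accuracy"].mean("seed")."""
    by_lr = {}
    for (lr, _seed), value in table.items():
        by_lr.setdefault(lr, []).append(value)
    return {lr: mean(values) for lr, values in by_lr.items()}


def select_lr(table, lr):
    """Keep only the entries for one learning rate, analogous to ds.sel(lr=0.1)."""
    return {seed: value for (key_lr, seed), value in table.items() if key_lr == lr}


per_lr = mean_over_seed(accuracy)
best_lr = max(per_lr, key=per_lr.get)
print(best_lr)                   # 0.1
print(select_lr(accuracy, 0.1))  # {0: 0.9, 1: 0.92, 2: 0.94}
```

With xarray the same reductions are addressed by dimension name rather than by tuple position, which is what keeps the analysis code readable as more swept parameters are added.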
Step 3: Compare methods with statistics
Once you have trained models (one list per method, one model per seed), pass
them to compare:
import torch
from torch import nn
from torch.utils.data import DataLoader

# Imports for compare and BenchmarkResult are omitted here, as are the
# definitions of _train, small_cnn, and mlp.


def run(
    train_loader: DataLoader, test_loader: DataLoader, *, seeds=(0, 1, 2)
) -> BenchmarkResult:
    """Train one CNN and one MLP per seed, then compare them with statistics."""
    methods: dict[str, list[nn.Module]] = {"cnn": [], "mlp": []}
    for seed in seeds:
        torch.manual_seed(seed)
        methods["cnn"].append(_train(small_cnn(), train_loader))
        methods["mlp"].append(_train(mlp(), train_loader))
    return compare(
        methods,
        data=test_loader,
        task="classification",
        num_classes=10,
        test="welch",
    )
compare evaluates every model on data, assembles a (method × seed)
xarray Dataset of metrics, and runs pairwise Holm-corrected significance tests.
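The test="welch" argument presumably selects Welch's unequal-variance t-test. As a rough illustration of what such a test computes for one pair of methods, the sketch below derives the t statistic and the Welch-Satterthwaite degrees of freedom from per-seed scores. This is not mushin's implementation: the function name welch_t and the accuracy values are hypothetical, and converting the statistic into a p-value would additionally require the t distribution (for example from SciPy).

```python
from math import sqrt
from statistics import mean, variance


def welch_t(a, b):
    """Return (t, df) for Welch's unequal-variance t-test on two samples."""
    na, nb = len(a), len(b)
    # Squared standard error of each sample mean (variance() divides by n - 1).
    se2_a = variance(a) / na
    se2_b = variance(b) / nb
    t = (mean(a) - mean(b)) / sqrt(se2_a + se2_b)
    # Welch-Satterthwaite approximation to the degrees of freedom.
    df = (se2_a + se2_b) ** 2 / (se2_a**2 / (na - 1) + se2_b**2 / (nb - 1))
    return t, df


# Hypothetical per-seed accuracies for two methods, three seeds each.
cnn_acc = [0.96, 0.97, 0.95]
mlp_acc = [0.94, 0.93, 0.95]

t, df = welch_t(cnn_acc, mlp_acc)
print(f"t = {t:.3f}, df = {df:.1f}")  # t = 2.449, df = 4.0
```

With only three seeds per method the degrees of freedom are small, so even a visible gap in means may not reach significance; more seeds tighten the comparison.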
Step 4: Read the statistics
result.summary()
# method | metric   | mean  | ci_low | ci_high | significant_vs_ref
# cnn    | accuracy | 0.963 | 0.951  | 0.975   |
# mlp    | accuracy | 0.941 | 0.928  | 0.954   | *
result.data # xarray.Dataset, dims (method, seed)
result.comparisons # tidy DataFrame with p-values and effect sizes
A "*" in the significant_vs_ref column means the method differs significantly
from the reference (the first method listed) after Holm correction at alpha=0.05.
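To show what Holm correction does to a set of pairwise p-values, here is a minimal sketch of the standard step-down procedure. It is a generic illustration, not mushin's code: the function name holm_reject and the raw p-values are hypothetical.

```python
def holm_reject(p_values, alpha=0.05):
    """Return a list of booleans: True where Holm's step-down procedure rejects."""
    m = len(p_values)
    # Visit hypotheses from the smallest to the largest p-value.
    order = sorted(range(m), key=lambda i: p_values[i])
    reject = [False] * m
    for rank, idx in enumerate(order):
        # The k-th smallest p-value (k = rank + 1) is compared to alpha / (m - k + 1).
        if p_values[idx] <= alpha / (m - rank):
            reject[idx] = True
        else:
            break  # once one comparison fails, all larger p-values are retained
    return reject


# Hypothetical raw p-values from three pairwise comparisons.
raw = [0.010, 0.040, 0.030]
print(holm_reject(raw))  # [True, False, False]
```

All three raw p-values are below 0.05, yet only the smallest survives correction. This is why a method can look different from the reference on its own and still carry no "*" once several comparisons are made.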
Next steps
- Core concepts — the mental model behind mushin
- Comparing methods — deeper coverage of compare
- Studies — combine training and comparison in one call
- Understanding the statistics — which test to choose