Quickstart
This page walks through the flagship example: run a parameter sweep and get the
results back as a labeled xarray.Dataset.
The full runnable script is at examples/sweep_to_dataset.py in the repository.
Define the workflow
Subclass MultiRunMetricsWorkflow and implement a static task method.
Each key in the dict it returns becomes a data variable in the output dataset;
no callbacks or logging framework are needed, just a plain dict.
import torch as tr

from mushin import multirun
from mushin.workflows import MultiRunMetricsWorkflow

LEARNING_RATES = [0.01, 0.1, 1.0]
SEEDS = [0, 1, 2]
POINTS_PER_CLASS = 256


def _make_data(seed: int, n: int = POINTS_PER_CLASS) -> tuple[tr.Tensor, tr.Tensor]:
    g = tr.Generator().manual_seed(seed)
    x0 = tr.randn(n, 2, generator=g) + tr.tensor([2.0, 2.0])
    x1 = tr.randn(n, 2, generator=g) + tr.tensor([-2.0, -2.0])
    x = tr.cat([x0, x1])
    y = tr.cat([tr.zeros(n), tr.ones(n)])
    return x, y


class LRSweep(MultiRunMetricsWorkflow):
    @staticmethod
    def task(lr: float, seed: int) -> dict:
        tr.manual_seed(seed)
        x, y = _make_data(seed)
        model = tr.nn.Linear(2, 1)
        opt = tr.optim.SGD(model.parameters(), lr=lr)
        for _ in range(100):
            opt.zero_grad()
            logits = model(x).squeeze(1)
            loss = tr.nn.functional.binary_cross_entropy_with_logits(logits, y)
            loss.backward()
            opt.step()
        with tr.no_grad():
            preds = (model(x).squeeze(1) > 0).float()
            acc = (preds == y).float().mean().item()
        # returning the dict is what populates the dataset; saving is optional
        result = dict(accuracy=acc)
        tr.save(result, "metrics.pt")
        return result
Run the sweep
Call wf.run(...) with multirun(...)-wrapped arguments, where wf is an instance
of the workflow class (here, LRSweep). Hydra launches one job per combination:
3 learning rates × 3 seeds = 9 runs in total.
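To make the run count concrete, the sweep covers the Cartesian product of the two argument lists. The sketch below is a plain-Python illustration only: it does not call mushin or Hydra, and `sweep_grid` is a hypothetical helper introduced here, not part of the library.

```python
from itertools import product

LEARNING_RATES = [0.01, 0.1, 1.0]
SEEDS = [0, 1, 2]


def sweep_grid(lrs, seeds):
    # Hypothetical helper (not a mushin API): every (lr, seed) pair in the sweep.
    return list(product(lrs, seeds))


combos = sweep_grid(LEARNING_RATES, SEEDS)

# One line per job the sweep would need to cover.
for job_index, (lr, seed) in enumerate(combos):
    print(f"job {job_index}: lr={lr}, seed={seed}")
```

The enumeration order here is just that of `itertools.product`; the order in which Hydra actually schedules jobs may differ.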
Get results as a dataset
Printing the resulting dataset (called ds below) should give output along these lines:
<xarray.Dataset>
Dimensions:   (lr: 3, seed: 3)
Coordinates:
  * lr        (lr) float64 0.01 0.1 1.0
  * seed      (seed) int64 0 1 2
Data variables:
    accuracy  (lr, seed) float64 ...
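One way to read this layout: every run contributes one value per returned key, and those values sit on a grid indexed by the swept arguments. The sketch below mimics that arrangement with plain Python lists. It is an illustration under stated assumptions, not how mushin builds the dataset internally: `placeholder_task` is a hypothetical stand-in for LRSweep.task, and the numbers it returns are dummies, not real accuracies.

```python
LEARNING_RATES = [0.01, 0.1, 1.0]
SEEDS = [0, 1, 2]


def placeholder_task(lr: float, seed: int) -> dict:
    # Hypothetical stand-in for LRSweep.task; the value is a dummy that
    # encodes the grid position, not a measured accuracy.
    return {"accuracy": 10 * LEARNING_RATES.index(lr) + seed}


# Collect one result dict per (lr, seed) combination.
results = {
    (lr, seed): placeholder_task(lr, seed)
    for lr in LEARNING_RATES
    for seed in SEEDS
}

# Arrange the "accuracy" entries as a grid: rows follow lr, columns follow seed.
accuracy_grid = [
    [results[(lr, seed)]["accuracy"] for seed in SEEDS]
    for lr in LEARNING_RATES
]

shape = (len(accuracy_grid), len(accuracy_grid[0]))
print(shape)
for lr, row in zip(LEARNING_RATES, accuracy_grid):
    print(lr, row)
```

In this picture, the dictionary key plays the role of the data variable name, and the two list axes play the role of the lr and seed dimensions.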
From there, standard xarray/pandas operations apply:
# average accuracy across seeds, per learning rate
mean_acc = ds["accuracy"].mean("seed")
print(mean_acc)
# plot
import matplotlib.pyplot as plt
mean_acc.plot.line(x="lr", marker="o")
plt.xscale("log")
plt.savefig("sweep_accuracy.png", dpi=120, bbox_inches="tight")
Run the full example
The complete script is examples/sweep_to_dataset.py in the repository.
Next steps
- Workflows & sweeps guide: more on BaseWorkflow and MultiRunMetricsWorkflow
- Comparing methods guide: evaluate trained models with statistics
- Studies guide: combine training + compare in one call
- API Reference: workflows