Workflows & sweeps¶
mushin workflows are declarative wrappers around Hydra multirun jobs. You
define your experiment as a method, run it once with swept parameters, and
mushin handles config logging, output directories, and assembling results into
a labeled xarray.Dataset.
The mental model¶
A mushin workflow has three steps:
- Define: subclass MultiRunMetricsWorkflow and implement a task(...) method that returns a dict of metrics.
- Run: call .run(...) with multirun(...)-wrapped arguments to launch a Hydra sweep.
- Collect: call .to_xarray() to get a labeled dataset keyed by swept dimensions.
Runnable example¶
The following example sweeps learning rates and seeds on a synthetic 2-class dataset:
class LRSweep(MultiRunMetricsWorkflow):
    @staticmethod
    def task(lr: float, seed: int) -> dict:
        tr.manual_seed(seed)
        x, y = _make_data(seed)
        model = tr.nn.Linear(2, 1)
        opt = tr.optim.SGD(model.parameters(), lr=lr)
        for _ in range(100):
            opt.zero_grad()
            logits = model(x).squeeze(1)
            loss = tr.nn.functional.binary_cross_entropy_with_logits(logits, y)
            loss.backward()
            opt.step()
        with tr.no_grad():
            preds = (model(x).squeeze(1) > 0).float()
            acc = (preds == y).float().mean().item()
        # returning the dict is what populates the dataset; saving is optional
        result = dict(accuracy=acc)
        tr.save(result, "metrics.pt")
        return result


def build_dataset(working_dir: Path | None = None):
    """Run the learning-rate x seed sweep and return an ``xarray.Dataset``."""
    wf = LRSweep()
    wf.run(
        lr=multirun(LEARNING_RATES),
        seed=multirun(SEEDS),
        working_dir=str(working_dir) if working_dir is not None else None,
    )
    return wf.to_xarray()
Each key in the dict returned from task becomes a data variable in the output
dataset. Any other keyword argument to .run(...) that is not wrapped in
multirun(...), apart from run options such as working_dir, is treated as a
fixed override for every run.
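To make the swept/fixed distinction concrete, here is a plain-Python sketch of the grid such a call describes. It does not use mushin or Hydra: expand_grid is a hypothetical helper, the epochs argument is invented for illustration, and the learning-rate and seed values are assumed because the example above does not show LEARNING_RATES or SEEDS.

```python
from itertools import product


def expand_grid(swept: dict, fixed: dict) -> list[dict]:
    """Return one kwargs dict per job: the Cartesian product of the swept
    values, with the fixed values copied into every job."""
    names = list(swept)
    jobs = []
    for combo in product(*(swept[name] for name in names)):
        job = dict(zip(names, combo))
        job.update(fixed)
        jobs.append(job)
    return jobs


# Assumed values; the original example does not show LEARNING_RATES or SEEDS.
jobs = expand_grid(
    swept={"lr": [0.01, 0.1, 1.0], "seed": [0, 1, 2]},
    fixed={"epochs": 100},  # hypothetical fixed argument
)

print(len(jobs))  # 9 jobs: 3 learning rates x 3 seeds
print(jobs[0])    # {'lr': 0.01, 'seed': 0, 'epochs': 100}
```

Only the swept names vary between jobs, which is why only they can serve as dataset dimensions.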
Getting results¶
ds = wf.to_xarray()
# <xarray.Dataset> Dimensions: (lr: 3, seed: 3)
# Data variables: accuracy (lr, seed)
ds["accuracy"].mean("seed") # average over seeds, per learning rate
ds.sel(lr=0.1) # slice to a single lr
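If you want to try these operations without running a sweep, the sketch below builds a stand-in dataset with the same layout by hand. The accuracy numbers and coordinate values are made up for illustration; only the shape (lr: 3, seed: 3) matches the output above.

```python
import numpy as np
import xarray as xr

# Made-up accuracies: rows are learning rates, columns are seeds.
accuracy = np.array(
    [
        [0.70, 0.72, 0.74],
        [0.90, 0.92, 0.94],
        [0.80, 0.82, 0.84],
    ]
)

ds = xr.Dataset(
    data_vars={"accuracy": (("lr", "seed"), accuracy)},
    coords={"lr": [0.01, 0.1, 1.0], "seed": [0, 1, 2]},
)

mean_acc = ds["accuracy"].mean("seed")  # one value per learning rate
best_lr = float(mean_acc.idxmax("lr"))  # lr coordinate with the highest mean
one_lr = ds.sel(lr=0.1)                 # only the seed dimension remains

print(mean_acc.values)
print(best_lr)
print(dict(one_lr.sizes))
```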
You can also save the dataset as NetCDF and reload it later (requires the netcdf extra).
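A minimal sketch of that round trip, using xarray's own to_netcdf and load_dataset on a hand-built stand-in dataset. The file name results.nc and the data are assumptions, and writing needs a NetCDF backend (such as netCDF4, h5netcdf, or scipy) to be installed.

```python
import tempfile
from pathlib import Path

import numpy as np
import xarray as xr

# Stand-in for the result of wf.to_xarray(); the values are made up.
ds = xr.Dataset(
    data_vars={"accuracy": (("lr", "seed"), np.full((3, 3), 0.5))},
    coords={"lr": [0.01, 0.1, 1.0], "seed": [0, 1, 2]},
)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "results.nc"  # hypothetical file name
    ds.to_netcdf(path)
    # load_dataset reads everything into memory and closes the file,
    # so the temporary directory can be removed safely afterwards.
    loaded = xr.load_dataset(path)

print(loaded.equals(ds))
```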
BaseWorkflow¶
BaseWorkflow is the base class for all mushin workflows. It orchestrates
Hydra jobs and exposes the raw results via .cfgs, .metrics, and .jobs
attributes after .run(...) completes.
You rarely subclass BaseWorkflow directly — use MultiRunMetricsWorkflow
instead, which adds the to_xarray() result aggregation layer.
RobustnessCurve¶
RobustnessCurve is a variant workflow for evaluating model robustness across
perturbation strengths (e.g. noise levels, attack epsilons). It shares the same
sweep-and-aggregate interface as MultiRunMetricsWorkflow.
See the API Reference — workflows for full parameter documentation.
hydra_list and multirun¶
- multirun(values): wraps a list as a Hydra multirun override; Hydra creates one job per value.
- hydra_list(values): wraps a list as a single Hydra list override; all values are passed as a list to one job.
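The difference mirrors Hydra's command-line override grammar, where comma-separated values denote a sweep and a bracketed value denotes a single list. The helpers below are hypothetical and only format strings to show the two shapes; how mushin builds its overrides internally is not covered here, and the layers parameter is invented for illustration.

```python
def as_sweep_override(name: str, values: list) -> str:
    """Hydra sweep syntax: comma-separated choices, one job per choice."""
    return f"{name}=" + ",".join(str(v) for v in values)


def as_list_override(name: str, values: list) -> str:
    """Hydra list syntax: a bracketed list delivered whole to a single job."""
    return f"{name}=[" + ",".join(str(v) for v in values) + "]"


print(as_sweep_override("lr", [0.01, 0.1, 1.0]))  # lr=0.01,0.1,1.0
print(as_list_override("layers", [64, 128]))      # layers=[64,128]
```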
Pitfalls
- task must return a dict: MultiRunMetricsWorkflow collects the returned dict as
  metrics. Returning None or a non-dict silently breaks to_xarray().
- Fixed vs swept args: Only multirun(...)-wrapped args become dataset dimensions;
  fixed args are recorded in the Hydra config but not in the xarray dims.
- Output directories: Hydra writes each job's output to a timestamped
  subdirectory. Pass working_dir=... to control the root.
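One way to turn the first pitfall into a loud failure is to check the return value yourself. The decorator below is a hypothetical sketch, not part of mushin, and it is shown on plain functions; whether it composes cleanly with a task staticmethod under Hydra has not been verified here.

```python
from collections.abc import Mapping
from functools import wraps


def require_metrics_dict(task_fn):
    """Wrap a task function so a non-dict return value raises immediately."""

    @wraps(task_fn)
    def wrapper(*args, **kwargs):
        result = task_fn(*args, **kwargs)
        if not isinstance(result, Mapping):
            raise TypeError(
                f"{task_fn.__name__} must return a dict of metrics, "
                f"got {type(result).__name__}"
            )
        return result

    return wrapper


@require_metrics_dict
def good_task(lr: float) -> dict:
    return {"accuracy": 0.9}


@require_metrics_dict
def bad_task(lr: float):
    acc = 0.9  # computed but never returned, so the function returns None


print(good_task(lr=0.1))  # {'accuracy': 0.9}
try:
    bad_task(lr=0.1)
except TypeError as err:
    print(err)  # bad_task must return a dict of metrics, got NoneType
```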
See also¶
- Tutorial — end-to-end: sweep → dataset → compare
- API Reference — workflows