
Workflows & sweeps

mushin workflows are declarative wrappers around Hydra multirun jobs. You define your experiment as a method, run it once with swept parameters, and mushin handles config logging, output directories, and assembling results into a labeled xarray.Dataset.

The mental model

A mushin workflow has three steps:

  1. Define: subclass MultiRunMetricsWorkflow and implement a task(...) method that returns a dict of metrics.
  2. Run: call .run(...), wrapping each swept argument in multirun(...), to launch a Hydra sweep.
  3. Collect: call .to_xarray() to get a labeled dataset with one dimension per swept argument.
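To make the three steps concrete without Hydra, here is a toy, pure-Python stand-in. `ToyWorkflow`, `run`, and `collect` are hypothetical names used only for illustration, not mushin's API; the real workflow launches each parameter combination as a separate Hydra job rather than calling `task` in a loop.

```python
from itertools import product


class ToyWorkflow:
    """Hypothetical stand-in for the define / run / collect cycle (not mushin's API)."""

    @staticmethod
    def task(lr: float, seed: int) -> dict:
        # Define: return a dict of metrics (placeholder formula, not a real model)
        return {"score": lr * 10 + seed}

    def run(self, **swept: list) -> None:
        # Run: one task call per combination of swept values, like a multirun grid
        self.names = list(swept)
        self.jobs = []
        for values in product(*swept.values()):
            params = dict(zip(self.names, values))
            self.jobs.append((params, self.task(**params)))

    def collect(self) -> dict:
        # Collect: map each metric name to {coordinate tuple: value}
        out: dict = {}
        for params, metrics in self.jobs:
            coords = tuple(params[n] for n in self.names)
            for key, value in metrics.items():
                out.setdefault(key, {})[coords] = value
        return out


wf = ToyWorkflow()
wf.run(lr=[0.01, 0.1], seed=[0, 1, 2])
results = wf.collect()
print(len(wf.jobs))                # 6 jobs: 2 learning rates x 3 seeds
print(results["score"][(0.1, 2)])  # 3.0
```

The coordinate tuples play the role that labeled dimensions play in the xarray.Dataset returned by .to_xarray().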

Runnable example

The following example sweeps learning rates and seeds on a synthetic 2-class dataset:

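from pathlib import Path

import torch as tr

from mushin import MultiRunMetricsWorkflow, multirun

# Illustrative sweep values: 3 learning rates x 3 seeds = 9 jobs
LEARNING_RATES = [0.01, 0.1, 1.0]
SEEDS = [0, 1, 2]


def _make_data(seed: int, n: int = 200):
    # Illustrative helper: two 2-D Gaussian blobs with float labels 0.0 and 1.0
    gen = tr.Generator().manual_seed(seed)
    x = tr.cat([tr.randn(n // 2, 2, generator=gen) - 1.0,
                tr.randn(n // 2, 2, generator=gen) + 1.0])
    y = tr.cat([tr.zeros(n // 2), tr.ones(n // 2)])
    return x, y


class LRSweep(MultiRunMetricsWorkflow):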
    @staticmethod
    def task(lr: float, seed: int) -> dict:
        tr.manual_seed(seed)
        x, y = _make_data(seed)
        model = tr.nn.Linear(2, 1)
        opt = tr.optim.SGD(model.parameters(), lr=lr)
        for _ in range(100):
            opt.zero_grad()
            logits = model(x).squeeze(1)
            loss = tr.nn.functional.binary_cross_entropy_with_logits(logits, y)
            loss.backward()
            opt.step()
        with tr.no_grad():
            preds = (model(x).squeeze(1) > 0).float()
            acc = (preds == y).float().mean().item()
        # returning the dict is what populates the dataset; saving is optional
        result = dict(accuracy=acc)
        tr.save(result, "metrics.pt")
        return result


def build_dataset(working_dir: Path | None = None):
    """Run the learning-rate x seed sweep and return an ``xarray.Dataset``."""
    wf = LRSweep()
    wf.run(
        lr=multirun(LEARNING_RATES),
        seed=multirun(SEEDS),
        working_dir=str(working_dir) if working_dir is not None else None,
    )
    return wf.to_xarray()

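Each key of the dict returned by task becomes a data variable in the output dataset, indexed by the swept arguments. Task arguments passed to .run(...) without a multirun(...) wrapper are treated as fixed overrides applied to every job; run options such as working_dir configure the launch itself and are not passed to task.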
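The split between swept and fixed arguments can be pictured with a hypothetical `Sweep` marker standing in for multirun(...): marked values become dimensions and expand into one job per combination, while everything else is copied unchanged into every job. This is a sketch of the behavior described above, not how mushin or Hydra implement it, and `epochs` is a hypothetical fixed parameter.

```python
from itertools import product
from typing import Any


class Sweep(list):
    """Hypothetical marker playing the role of multirun(...): a list of values to sweep."""


def plan_jobs(**kwargs: Any) -> tuple[list[str], list[dict]]:
    # Swept args become dimensions; all others are fixed overrides shared by every job
    dims = [k for k, v in kwargs.items() if isinstance(v, Sweep)]
    fixed = {k: v for k, v in kwargs.items() if not isinstance(v, Sweep)}
    jobs = []
    for combo in product(*(kwargs[d] for d in dims)):
        jobs.append({**fixed, **dict(zip(dims, combo))})
    return dims, jobs


dims, jobs = plan_jobs(lr=Sweep([0.01, 0.1, 1.0]), seed=Sweep([0, 1, 2]), epochs=100)
print(dims)       # ['lr', 'seed']
print(len(jobs))  # 9
print(jobs[0])    # {'epochs': 100, 'lr': 0.01, 'seed': 0}
```

Only `lr` and `seed` show up in `dims`, which mirrors why fixed arguments never become dataset dimensions.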

Getting results

ds = wf.to_xarray()
# <xarray.Dataset> Dimensions: (lr: 3, seed: 3)
#   Data variables: accuracy (lr, seed)

ds["accuracy"].mean("seed")     # average over seeds, per learning rate
ds.sel(lr=0.1)                  # slice to a single lr
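If you are new to xarray, `ds["accuracy"].mean("seed")` collapses the `seed` dimension and keeps `lr`. The plain-Python sketch below performs the same two operations on a nested dict; the accuracy values are placeholders chosen for illustration, not results from the sweep above.

```python
from statistics import mean

# Placeholder accuracies indexed as accuracy[lr][seed] (illustrative values only)
accuracy = {
    0.01: {0: 0.25, 1: 0.5, 2: 0.75},
    0.1: {0: 0.5, 1: 0.75, 2: 1.0},
}

# Like ds["accuracy"].mean("seed"): average over seeds, one value per lr
mean_over_seed = {lr: mean(by_seed.values()) for lr, by_seed in accuracy.items()}

# Like ds.sel(lr=0.1)["accuracy"]: every seed for a single learning rate
at_lr_01 = accuracy[0.1]

print(mean_over_seed)  # {0.01: 0.5, 0.1: 0.75}
print(at_lr_01)        # {0: 0.5, 1: 0.75, 2: 1.0}
```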

You can also save and reload the dataset as NetCDF (requires the netcdf extra):

import xarray as xr

ds.to_netcdf("results.nc")          # write the dataset to disk
ds = xr.open_dataset("results.nc")  # reload it later

BaseWorkflow

BaseWorkflow is the base class for all mushin workflows. It orchestrates Hydra jobs and exposes the raw results via .cfgs, .metrics, and .jobs attributes after .run(...) completes.

You rarely subclass BaseWorkflow directly — use MultiRunMetricsWorkflow instead, which adds the to_xarray() result aggregation layer.
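Because .cfgs and .metrics hold the raw per-job results, simple post-processing does not need to_xarray() at all. The sketch below assumes the two attributes are parallel sequences (entry i of each describing the same job) and uses plain dicts with placeholder values in their place; `best_job` is a hypothetical helper, and in practice the configs may be OmegaConf DictConfig objects rather than dicts.

```python
from collections.abc import Mapping, Sequence
from typing import Any


def best_job(
    cfgs: Sequence[Mapping[str, Any]],
    metrics: Sequence[Mapping[str, float]],
    key: str,
) -> tuple[Mapping[str, Any], float]:
    """Return the config and score of the job with the highest metrics[key].

    Hypothetical helper; assumes cfgs[i] and metrics[i] describe the same job.
    """
    if len(cfgs) != len(metrics):
        raise ValueError("cfgs and metrics must have one entry per job")
    i = max(range(len(metrics)), key=lambda j: metrics[j][key])
    return cfgs[i], metrics[i][key]


# Placeholder stand-ins for wf.cfgs and wf.metrics after a three-job sweep
cfgs = [{"lr": 0.01}, {"lr": 0.1}, {"lr": 1.0}]
metrics = [{"accuracy": 0.5}, {"accuracy": 0.75}, {"accuracy": 0.625}]

cfg, score = best_job(cfgs, metrics, key="accuracy")
print(cfg, score)  # {'lr': 0.1} 0.75
```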

RobustnessCurve

RobustnessCurve is a variant workflow for evaluating model robustness across perturbation strengths (e.g. noise levels, attack epsilons). It shares the same sweep-and-aggregate interface as MultiRunMetricsWorkflow.

See the API Reference — workflows for full parameter documentation.
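A robustness curve is one metric evaluated along a single swept perturbation-strength axis. The stdlib-only sketch below illustrates the idea with a toy 1-D threshold classifier and Gaussian input noise; `accuracy_under_noise` and the data are hypothetical and do not use RobustnessCurve's actual interface. In a mushin workflow, each perturbation strength would instead be one job in a multirun sweep.

```python
import random


def accuracy_under_noise(eps: float, n: int = 1000, seed: int = 0) -> float:
    """Accuracy of the toy classifier sign(x) on points at x = +/-1 plus N(0, eps^2) noise."""
    rng = random.Random(seed)
    correct = 0
    for i in range(n):
        label = 1 if i % 2 == 0 else -1        # balanced classes
        x = label * 1.0 + rng.gauss(0.0, eps)  # clean point at +/-1, then perturbed
        pred = 1 if x > 0 else -1
        correct += int(pred == label)
    return correct / n


# One evaluation per perturbation strength, as a sweep over epsilon would produce
epsilons = [0.0, 0.5, 1.0, 2.0]
curve = {eps: accuracy_under_noise(eps) for eps in epsilons}
for eps, acc in curve.items():
    print(f"eps={eps:.1f}  accuracy={acc:.3f}")
```

With no noise the clean points are perfectly separable, so accuracy starts at 1.0 and degrades as the noise scale grows toward the distance between the classes.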

hydra_list and multirun

from mushin import multirun, hydra_list
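  • multirun(values): wraps a list as a Hydra multirun (sweep) override; Hydra launches one job per value, and when several arguments are wrapped, one job per combination of their values.
  • hydra_list(values): wraps a list as a single list-valued Hydra override; the whole list is passed as one value to every job instead of being swept over.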
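At the Hydra level, the difference shows up in override syntax: a comma-separated value such as `lr=0.01,0.1` is a sweep in multirun mode, while a bracketed value such as `layers=[64,32]` is a single list-valued setting. The hypothetical `to_override` helper below renders both forms and counts the jobs a grid sweep would launch; it illustrates Hydra's override grammar, and mushin's own translation may differ in details. `layers` is a hypothetical parameter.

```python
from math import prod


def to_override(name: str, values: list, as_sweep: bool) -> str:
    """Render a Hydra command-line override (hypothetical helper).

    as_sweep=True  -> "name=a,b,c"   (a sweep: one job per value, like multirun)
    as_sweep=False -> "name=[a,b,c]" (one list-valued setting, like hydra_list)
    """
    body = ",".join(str(v) for v in values)
    return f"{name}={body}" if as_sweep else f"{name}=[{body}]"


specs = [
    ("lr", [0.01, 0.1, 1.0], True),   # swept
    ("seed", [0, 1], True),           # swept
    ("layers", [64, 32], False),      # passed whole to every job
]
overrides = [to_override(name, values, sweep) for name, values, sweep in specs]

# A grid sweep runs the Cartesian product of swept values; list settings add no jobs
n_jobs = prod(len(values) for _, values, sweep in specs if sweep)

print(overrides)  # ['lr=0.01,0.1,1.0', 'seed=0,1', 'layers=[64,32]']
print(n_jobs)     # 6
```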

Pitfalls

  • task must return a dict: MultiRunMetricsWorkflow collects the returned dict as metrics. Returning None or a non-dict silently breaks to_xarray().
  • Fixed vs swept args: Only multirun(...)-wrapped args become dataset dimensions; fixed args are recorded in the Hydra config but not in the xarray dims.
  • Output directories: Hydra writes each job's output to a timestamped subdirectory. Pass working_dir=... to control the root.
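One way to catch the first pitfall early is to validate task's return value inside each job. `returns_metrics` below is a hypothetical decorator, not part of mushin: it raises a TypeError in the offending job instead of leaving the problem to surface in to_xarray(). Using functools.wraps keeps the original function name and exposes it via `__wrapped__`, which inspect.signature follows.

```python
import functools
from collections.abc import Callable, Mapping
from typing import Any


def returns_metrics(fn: Callable[..., Any]) -> Callable[..., Mapping[str, Any]]:
    """Hypothetical decorator: fail fast if a task does not return a mapping of metrics."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Mapping[str, Any]:
        result = fn(*args, **kwargs)
        if not isinstance(result, Mapping):
            raise TypeError(
                f"{fn.__name__} must return a dict of metrics, got {type(result).__name__}"
            )
        return result

    return wrapper


class Example:
    @staticmethod
    @returns_metrics
    def task(lr: float) -> dict:
        return {"accuracy": 1.0 - lr}  # placeholder metric

    @staticmethod
    @returns_metrics
    def broken_task(lr: float) -> dict:
        acc = 1.0 - lr  # forgot `return {"accuracy": acc}`, so this returns None


print(Example.task(0.25))  # {'accuracy': 0.75}
try:
    Example.broken_task(0.25)
except TypeError as err:
    print(err)  # broken_task must return a dict of metrics, got NoneType
```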

See also