Segmentation
compare and Study support semantic segmentation via task="segmentation".
Models receive (N, C, H, W) input tensors and must produce (N, num_classes, H, W)
logit tensors; the default predict_fn computes the per-pixel softmax probabilities
and their argmax over classes for you.
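The default behaviour described above can be sketched as follows. This is an illustrative re-implementation, not the library's code: it assumes the default simply applies a softmax over the class dimension (dim=1) and takes the per-pixel argmax. The name sketch_default_predict is hypothetical.

```python
import torch
from torch import nn


def sketch_default_predict(model: nn.Module, x: torch.Tensor):
    """Hypothetical stand-in for the default segmentation predict_fn."""
    logits = model(x)              # (N, num_classes, H, W)
    probs = logits.softmax(dim=1)  # per-pixel class probabilities
    preds = probs.argmax(dim=1)    # (N, H, W) long tensor of class indices
    return preds, probs


# Any model that maps (N, C, H, W) -> (N, num_classes, H, W) works here.
model = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=1)
x = torch.randn(2, 3, 8, 8)
with torch.no_grad():
    preds, probs = sketch_default_predict(model, x)
print(preds.shape, probs.shape)  # torch.Size([2, 8, 8]) torch.Size([2, 4, 8, 8])
```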
Runnable example
The example below compares two tiny segmentation models on synthetic pixel masks:
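import torch
from torch import nn
from torch.utils.data import DataLoader

# Assumed to be in scope (not shown in this excerpt): compare, BenchmarkResult,
# and the helpers _train and _tiny_seg_model.


def run(
    loader: DataLoader,
    *,
    in_channels: int = 3,
    num_classes: int = 4,
    seeds=(0, 1, 2),
) -> BenchmarkResult:
    """Train two tiny segmentation models per seed, then compare."""
    methods: dict[str, list[nn.Module]] = {"model_a": [], "model_b": []}
    for seed in seeds:
        # Different seeds per method, so model_a and model_b never share an initialisation.
        torch.manual_seed(seed)
        methods["model_a"].append(
            _train(_tiny_seg_model(in_channels, num_classes), loader)
        )
        torch.manual_seed(seed + 100)
        methods["model_b"].append(
            _train(_tiny_seg_model(in_channels, num_classes), loader)
        )
    # One trained model per seed and method; test="welch" requests a Welch test across seeds.
    return compare(
        methods,
        data=loader,
        task="segmentation",
        num_classes=num_classes,
        test="welch",
    )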
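The example relies on two helpers, _tiny_seg_model and _train, and on a DataLoader of synthetic masks, none of which are shown on this page. A minimal sketch of what they could look like, assuming a two-layer convolutional model, a few epochs of Adam with pixel-wise cross-entropy, and random images whose mask is a quantised version of channel 0 (so the task is learnable). These bodies, and make_synthetic_loader, are hypothetical, not the ones shipped with the project.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def _tiny_seg_model(in_channels: int, num_classes: int) -> nn.Module:
    """Hypothetical helper: two conv layers; padding keeps H and W unchanged."""
    return nn.Sequential(
        nn.Conv2d(in_channels, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, num_classes, kernel_size=1),  # (N, num_classes, H, W) logits
    )


def _train(model: nn.Module, loader: DataLoader, epochs: int = 3) -> nn.Module:
    """Hypothetical helper: plain pixel-wise cross-entropy training."""
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    loss_fn = nn.CrossEntropyLoss()  # accepts (N, C, H, W) logits and (N, H, W) targets
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            opt.step()
    return model.eval()


def make_synthetic_loader(n=32, in_channels=3, num_classes=4, size=16, seed=0):
    """Hypothetical data: each pixel's class is channel 0 squashed into num_classes bins."""
    g = torch.Generator().manual_seed(seed)
    x = torch.randn(n, in_channels, size, size, generator=g)
    y = (x[:, 0].sigmoid() * num_classes).long().clamp(max=num_classes - 1)
    return DataLoader(TensorDataset(x, y), batch_size=8)


loader = make_synthetic_loader()
model = _train(_tiny_seg_model(3, 4), loader)
x, y = next(iter(loader))
with torch.no_grad():
    print(model(x).shape)  # torch.Size([8, 4, 16, 16])
```

With these definitions in scope, plus compare and BenchmarkResult imported from the library, run(make_synthetic_loader()) should train three models per method and return the comparison.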
The default segmentation battery includes:

| Metric | Notes |
|---|---|
| miou | Mean Intersection over Union (macro-averaged) |
| dice | Macro-averaged Dice coefficient (= macro F1) |
| pixel_acc | Micro-averaged pixel accuracy |
| precision | Macro-averaged per-class precision |
| recall | Macro-averaged per-class recall |

All five are computed from a single C × C confusion matrix (C = num_classes) via torchmetrics, so streaming evaluation needs only O(C²) memory, independent of image size and dataset size.
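To make the table concrete, the sketch below derives all five metrics from one confusion matrix in plain Python, using the textbook definitions (rows are targets, columns are predictions). It is illustrative only: skipping classes with a zero denominator in each macro average is an assumption of this sketch, and torchmetrics may handle absent classes differently.

```python
def metrics_from_confmat(cm: list[list[int]]) -> dict[str, float]:
    """Textbook segmentation metrics from a C x C confusion matrix.

    cm[t][p] counts pixels whose target class is t and predicted class is p.
    """
    num_classes = len(cm)
    total = sum(sum(row) for row in cm)
    iou, dice, prec, rec = [], [], [], []
    for c in range(num_classes):
        tp = cm[c][c]
        fp = sum(cm[t][c] for t in range(num_classes)) - tp  # predicted c, target != c
        fn = sum(cm[c]) - tp                                  # target c, predicted != c
        if tp + fp + fn:
            iou.append(tp / (tp + fp + fn))
            dice.append(2 * tp / (2 * tp + fp + fn))  # per-class Dice == per-class F1
        if tp + fp:
            prec.append(tp / (tp + fp))
        if tp + fn:
            rec.append(tp / (tp + fn))

    def mean(xs):
        return sum(xs) / len(xs) if xs else 0.0

    return {
        "miou": mean(iou),
        "dice": mean(dice),
        "pixel_acc": sum(cm[c][c] for c in range(num_classes)) / total,  # micro average
        "precision": mean(prec),
        "recall": mean(rec),
    }


cm = [
    [5, 1, 0],
    [2, 3, 1],
    [0, 0, 4],
]
for name, value in metrics_from_confmat(cm).items():
    print(f"{name}: {value:.4f}")
# miou: 0.6179, dice: 0.7527, pixel_acc: 0.7500, precision: 0.7548, recall: 0.7778
```

Only the C × C counts are needed, however many batches are streamed, which is where the O(C²) memory bound comes from.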
Ignoring void / boundary labels
Many segmentation datasets use a special label (e.g. 255 in PASCAL VOC) to
mark void or boundary pixels. Pass ignore_index to exclude these from all
metrics:
# fcn_models and deeplab_models are each a list of trained nn.Module (one per seed)
result = compare(
methods={"fcn": fcn_models, "deeplab": deeplab_models},
data=val_loader,
task="segmentation",
num_classes=21,
ignore_index=255,
)
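Conceptually, ignore_index drops the masked pixels before they reach the confusion matrix, so they count neither as hits nor as errors for any class. A plain-Python sketch of that streaming update, assuming flattened per-pixel label lists; update_confmat is a hypothetical name that mirrors the idea, not the library's internal code.

```python
def update_confmat(cm, preds, target, ignore_index=None):
    """Accumulate one batch of flattened per-pixel labels into cm (in place).

    Pixels whose *target* equals ignore_index are skipped entirely.
    """
    for p, t in zip(preds, target):
        if ignore_index is not None and t == ignore_index:
            continue  # void / boundary pixel: excluded from every metric
        cm[t][p] += 1
    return cm


num_classes = 3
cm = [[0] * num_classes for _ in range(num_classes)]
# Two "batches" streamed one after the other; 255 marks void pixels.
update_confmat(cm, preds=[0, 1, 2, 1], target=[0, 255, 2, 1], ignore_index=255)
update_confmat(cm, preds=[2, 0, 0], target=[255, 0, 1], ignore_index=255)
print(cm)  # [[2, 0, 0], [1, 1, 0], [0, 0, 1]]
```

Without ignore_index, a target of 255 is not a valid class index at all for a 21-class problem, which is why void labels must be excluded rather than counted.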
Custom predict_fn for models that return dicts
Some models (e.g. those in torchvision.models.segmentation) return a dict instead of
a plain tensor. Use predict_fn to adapt the output:
def torchvision_seg_predict(model, x):
    """Adapt a torchvision segmentation model (returns {"out": logits})."""
    logits = model(x)["out"]
    probs = logits.softmax(dim=1)
    return probs.argmax(dim=1), probs
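To exercise the adapter without downloading torchvision weights, you can wrap any module so that it returns {"out": logits}, mimicking the output format of torchvision's segmentation models. DictOutputWrapper below is a hypothetical test double, not part of torchvision or this library.

```python
import torch
from torch import nn


def torchvision_seg_predict(model, x):
    """Adapt a torchvision segmentation model (returns {"out": logits})."""
    logits = model(x)["out"]
    probs = logits.softmax(dim=1)
    return probs.argmax(dim=1), probs


class DictOutputWrapper(nn.Module):
    """Hypothetical test double: returns a dict the way torchvision seg models do."""

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.head = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return {"out": self.head(x)}


model = DictOutputWrapper(in_channels=3, num_classes=21).eval()
with torch.no_grad():
    preds, probs = torchvision_seg_predict(model, torch.randn(2, 3, 32, 32))
print(preds.shape, probs.shape)  # torch.Size([2, 32, 32]) torch.Size([2, 21, 32, 32])
```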
Pass it to compare:
compare(
{"fcn": fcn_models, "deeplab": deeplab_models},
data=val_loader,
task="segmentation",
num_classes=21,
ignore_index=255,
predict_fn=torchvision_seg_predict,
)
The predict_fn signature is (model, batch_x) -> (predictions, probabilities),
where predictions is a (N, H, W) long tensor of class indices and
probabilities is a (N, C, H, W) float tensor of per-class probabilities.
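Before a long run, it can be worth checking a custom predict_fn against this contract on a single batch. check_predict_fn is a hypothetical helper written for illustration, not a library function; it only asserts the shapes and dtypes stated above, and skips the probability checks when the same tensor is returned twice.

```python
import torch
from torch import nn


def check_predict_fn(predict_fn, model, x, num_classes):
    """Hypothetical helper: assert predict_fn(model, x) follows the contract."""
    with torch.no_grad():
        out = predict_fn(model, x)
    assert isinstance(out, tuple) and len(out) == 2, "predict_fn must return a 2-tuple"
    preds, probs = out
    n, _, h, w = x.shape
    assert preds.shape == (n, h, w), f"preds shape {tuple(preds.shape)}, expected {(n, h, w)}"
    assert preds.dtype == torch.long, f"preds dtype is {preds.dtype}, expected torch.long"
    if probs is not preds:  # (preds, preds) carries no probabilities, so skip these checks
        assert probs.shape == (n, num_classes, h, w), f"probs shape {tuple(probs.shape)}"
        assert probs.is_floating_point(), "probs must be a float tensor"
    return preds, probs


def softmax_predict(model, x):
    probs = model(x).softmax(dim=1)
    return probs.argmax(dim=1), probs


model = nn.Conv2d(3, 5, kernel_size=1)
x = torch.randn(4, 3, 10, 12)
preds, probs = check_predict_fn(softmax_predict, model, x, num_classes=5)
print("contract OK:", tuple(preds.shape), tuple(probs.shape))
# contract OK: (4, 10, 12) (4, 5, 10, 12)
```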
predict_fn must always return a 2-tuple
The return value is always a (predictions, probabilities) pair. If you have no
probabilities to provide, return the predictions twice
(return preds, preds). For task="segmentation" the default battery contains no
probability-based metrics (prob_metrics is empty), so the duplicate is never used.
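For instance, a model that only emits hard labels has no probabilities, so its predict_fn returns the label map twice. ThresholdSegmenter below is a hypothetical non-learned baseline invented for this sketch (three classes from two thresholds on channel 0).

```python
import torch
from torch import nn


class ThresholdSegmenter(nn.Module):
    """Hypothetical baseline: 3 classes from two thresholds on channel 0."""

    def forward(self, x):
        c0 = x[:, 0]  # (N, H, W)
        return (c0 > -0.5).long() + (c0 > 0.5).long()  # values in {0, 1, 2}


def hard_label_predict(model, x):
    preds = model(x).long()
    return preds, preds  # no probabilities: duplicate the predictions


preds, probs = hard_label_predict(ThresholdSegmenter(), torch.randn(2, 3, 8, 8))
print(preds.shape, preds.dtype, probs is preds)  # torch.Size([2, 8, 8]) torch.int64 True
```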
Using Study for segmentation
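import torch
from mushin import Study

# train_fcn / train_deeplab are your own training callables (not shown);
# val_loader is your validation DataLoader.
study = Study(
    methods={"fcn": train_fcn, "deeplab": train_deeplab},
    load_fn=lambda p: torch.load(p, weights_only=False),
    seeds=[0, 1, 2],
    data=val_loader,
    task="segmentation",
    num_classes=21,
    ignore_index=255,
)
result = study.run()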
Pitfalls
- Input shape: Models must accept (N, C, H, W) input and return (N, num_classes, H, W) logits. A 1×1 Conv2d is the minimal example.
- ignore_index: AUROC and ECE do not support ignore_index, but the segmentation battery contains neither, so ignore_index works correctly for all five segmentation metrics.
- Dict-output models: Always wrap them with a predict_fn; passing a dict to the default predict_fn will raise an error.
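A quick pre-flight check for the input-shape and dict-output pitfalls is to run one batch through the model and confirm it returns a plain (N, num_classes, H, W) tensor. preflight is a hypothetical helper written for illustration; the 1×1 Conv2d mirrors the minimal model mentioned above.

```python
import torch
from torch import nn


def preflight(model, x: torch.Tensor, num_classes: int) -> str:
    """Hypothetical helper: report whether model(x) suits the default predict_fn."""
    with torch.no_grad():
        out = model(x)
    if isinstance(out, dict):
        return f"dict output with keys {sorted(out)}: wrap it with a predict_fn"
    n, _, h, w = x.shape
    if out.shape != (n, num_classes, h, w):
        return f"got {tuple(out.shape)}, expected {(n, num_classes, h, w)}"
    return "ok"


x = torch.randn(2, 3, 16, 16)
print(preflight(nn.Conv2d(3, 4, kernel_size=1), x, num_classes=4))  # ok
# A stride-2 conv halves H and W, which breaks the per-pixel contract.
print(preflight(nn.Conv2d(3, 4, kernel_size=1, stride=2), x, num_classes=4))
# got (2, 4, 8, 8), expected (2, 4, 16, 16)
```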
See also
- Comparing methods guide — statistical tests and result reading
- Custom metrics & predict_fn — override the metric battery
- API Reference — benchmark