
Custom metrics & predict_fn

mushin's metric batteries and prediction logic are fully replaceable. This guide shows how to extend mushin with custom metrics and how to adapt models that don't return plain tensors.

These are compare arguments

metrics, predict_fn, and prob_metrics are arguments to compare. Study always evaluates with the default battery for its task (classification or segmentation) — to use a custom battery or predict step with trained models, call compare directly.

Custom metrics dict

Pass a metrics dict to compare to replace the default battery:

from torchmetrics.classification import MulticlassF1Score, MulticlassAccuracy

# cnn_models is a list of trained nn.Module (one per seed)
compare(
    methods={"cnn": cnn_models},
    data=val_loader,
    task="classification",     # still sets the default predict_fn
    metrics={
        "accuracy": MulticlassAccuracy(num_classes=10, average="micro"),
        "f1_macro": MulticlassF1Score(num_classes=10, average="macro"),
        "f1_weighted": MulticlassF1Score(num_classes=10, average="weighted"),
    },
    # num_classes is not required when metrics is provided
)

Each value must be a torchmetrics.Metric instance. The keys become the data variable names in result.data.

prob_metrics

Some metrics require class probabilities rather than hard predictions (e.g. AUROC, ECE). mushin uses the prob_metrics frozenset to decide which metrics receive probabilities. The default is the task's built-in set; override it when you add probability-based custom metrics:

from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy

# cnn_models is a list of trained nn.Module (one per seed)
compare(
    methods={"cnn": cnn_models},
    data=val_loader,
    task="classification",
    metrics={
        "accuracy": MulticlassAccuracy(num_classes=10, average="micro"),
        "auroc": MulticlassAUROC(num_classes=10),
    },
    prob_metrics=frozenset({"auroc"}),  # feed probabilities only to auroc
    num_classes=10,
)

Custom predict_fn

The default predict_fn calls the model, takes the argmax over class logits, and returns (predictions, softmax_probabilities). Replace it when your model returns something other than a plain (N, C) or (N, C, H, W) logit tensor.
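The default behaviour described above fits in a few lines. default_like_predict is a hypothetical name, and mushin's real default may differ in details such as no_grad or eval-mode handling; the point here is the shape and dtype of what comes back.

```python
import torch
from torch import Tensor, nn


def default_like_predict(model: nn.Module, x: Tensor) -> tuple[Tensor, Tensor]:
    """Hypothetical re-creation of the described default predict step."""
    logits = model(x)                 # (N, C) or (N, C, H, W)
    probs = logits.softmax(dim=1)     # the class axis is dim 1 in both layouts
    preds = logits.argmax(dim=1)      # (N,) or (N, H, W), dtype torch.long
    return preds, probs


# Classification: a linear layer maps 8 features to 5 class logits.
cls_preds, cls_probs = default_like_predict(nn.Linear(8, 5), torch.randn(4, 8))
# Segmentation: a 1x1 convolution maps 3 channels to 4 per-pixel class logits.
seg_preds, seg_probs = default_like_predict(nn.Conv2d(3, 4, kernel_size=1), torch.randn(2, 3, 6, 6))
print(cls_preds.shape, cls_probs.shape)  # torch.Size([4]) torch.Size([4, 5])
print(seg_preds.shape, seg_probs.shape)  # torch.Size([2, 6, 6]) torch.Size([2, 4, 6, 6])
```

Because softmax is monotonic, taking the argmax over logits or over probabilities gives the same class indices.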

Adapting torchvision segmentation models

torchvision segmentation models return a dict {"out": logits, ...}. Here is the adapter from the segmentation example:

def torchvision_seg_predict(model, x):
    """Adapt a torchvision segmentation model (returns {"out": logits})."""
    logits = model(x)["out"]
    probs = logits.softmax(dim=1)
    return probs.argmax(dim=1), probs
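torchvision's segmentation models return an OrderedDict whose "out" entry holds the (N, C, H, W) logits, plus an "aux" entry when aux_loss is enabled. To check the adapter without downloading weights, the sketch below uses a hypothetical DictOutputSeg module that only mimics that output format.

```python
import torch
from torch import Tensor, nn


class DictOutputSeg(nn.Module):
    """Hypothetical stand-in mimicking torchvision's {"out": logits} output."""

    def __init__(self, in_channels: int = 3, num_classes: int = 21) -> None:
        super().__init__()
        self.head = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x: Tensor) -> dict[str, Tensor]:
        return {"out": self.head(x)}


def torchvision_seg_predict(model, x):
    """Adapt a torchvision segmentation model (returns {"out": logits})."""
    logits = model(x)["out"]
    probs = logits.softmax(dim=1)
    return probs.argmax(dim=1), probs


preds, probs = torchvision_seg_predict(DictOutputSeg(), torch.randn(2, 3, 8, 8))
print(preds.shape, probs.shape)  # torch.Size([2, 8, 8]) torch.Size([2, 21, 8, 8])
```

In recent torchvision versions, torchvision.models.segmentation.fcn_resnet50(weights=None, weights_backbone=None, num_classes=21) builds an untrained FCN with the same dict output, if you want to try the adapter on a real architecture offline.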

Pass it to compare:

# fcn_models is a list of trained nn.Module (one per seed)
compare(
    {"fcn": fcn_models},
    data=val_loader,
    task="segmentation",
    num_classes=21,
    ignore_index=255,
    predict_fn=torchvision_seg_predict,
)

predict_fn contract

The predict_fn signature is:

from torch import Tensor, nn

def predict_fn(model: nn.Module, x: Tensor) -> tuple[Tensor, Tensor]:
    ...  # run the model and derive class indices and per-class probabilities
    return predictions, probabilities

  • predictions: long tensor of class indices — (N,) for classification, (N, H, W) for segmentation.
  • probabilities: float tensor of per-class probabilities — (N, C) for classification, (N, C, H, W) for segmentation.
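A cheap way to catch contract violations before a long evaluation run is to call your predict_fn on one batch and inspect what comes back. check_predict_output below is a hypothetical helper (mushin may or may not run similar checks itself); it encodes the shapes and dtypes listed above and also accepts the "predictions returned twice" fallback.

```python
import torch
from torch import Tensor


def check_predict_output(out, task: str) -> None:
    """Hypothetical helper: raise if predict_fn output breaks the contract."""
    # The evaluation loop unpacks two elements, so insist on an actual 2-tuple.
    if not (isinstance(out, tuple) and len(out) == 2):
        raise TypeError("predict_fn must return a (predictions, probabilities) tuple")
    preds, probs = out
    if preds.dtype != torch.long:
        raise TypeError(f"predictions must be torch.long, got {preds.dtype}")
    expected_ndim = {"classification": 1, "segmentation": 3}[task]  # (N,) or (N, H, W)
    if preds.ndim != expected_ndim:
        raise ValueError(f"{task} predictions need {expected_ndim} dims, got {preds.ndim}")
    if probs.shape == preds.shape:
        return  # no probabilities available: predictions returned twice
    # Otherwise probabilities add a class axis at dim 1: (N, C) or (N, C, H, W).
    if (probs.ndim != preds.ndim + 1 or probs.shape[0] != preds.shape[0]
            or probs.shape[2:] != preds.shape[1:]):
        raise ValueError(f"probabilities shape {tuple(probs.shape)} does not match predictions")
    if not probs.is_floating_point():
        raise TypeError("probabilities must be a float tensor")
```

For example, check_predict_output(my_predict_fn(model, x), "segmentation") on a single validation batch, where my_predict_fn, model, and x are your own objects.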

If no probabilities are available, return predictions twice; the second element is only consumed by metrics listed in prob_metrics.
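For instance, a model that already outputs class indices has no probabilities to give. The sketch below uses a hypothetical LabelOnlyModel with a toy thresholding rule and a hypothetical labels_only_predict that returns the predictions in both slots. If the task's default battery includes probability-based metrics, you would presumably pair this with a custom metrics dict and an empty prob_metrics so that no metric receives the duplicated hard predictions.

```python
import torch
from torch import Tensor, nn


class LabelOnlyModel(nn.Module):
    """Hypothetical model whose forward() already returns class indices."""

    def forward(self, x: Tensor) -> Tensor:
        # Toy rule: class 1 if the mean feature is positive, else class 0.
        return (x.mean(dim=1) > 0).long()


def labels_only_predict(model: nn.Module, x: Tensor) -> tuple[Tensor, Tensor]:
    """Return the hard predictions in both slots; no probabilities exist."""
    preds = model(x)
    return preds, preds


preds, probs = labels_only_predict(LabelOnlyModel(), torch.tensor([[1.0, 2.0], [-1.0, -3.0]]))
print(preds)  # tensor([1, 0])
```

In compare you would then pass predict_fn=labels_only_predict together with, for example, metrics={"accuracy": MulticlassAccuracy(num_classes=2, average="micro")} and prob_metrics=frozenset().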

Pitfalls

  • Always return a 2-tuple: the evaluation loop always unpacks both elements, so returning only predictions will typically raise a ValueError.
  • prob_metrics mismatch: If you add a probability-based metric but forget to add its name to prob_metrics, it receives hard predictions and will likely error or silently produce wrong results.
  • Metric state: torchmetrics metrics are stateful; mushin calls .reset() before each model evaluation. Do not share metric instances across calls.

See also