Source code for trustlens.metrics.failure

"""
trustlens.metrics.failure.
==========================
Failure-mode analysis: where and how does a model fail?

Metrics implemented
-------------------
* ``misclassification_summary`` — per-class error rates and high-confidence
  mistakes.
* ``confidence_gap``      — distribution of confidence for correct vs.
  incorrect predictions.
"""

from __future__ import annotations

import numpy as np


def misclassification_summary(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_prob: np.ndarray,
    top_k: int = 5,
) -> dict:
    """
    Build a comprehensive misclassification summary.

    For each class present in ``y_true``, reports:

    * total support (ground truth count)
    * number of misclassified samples
    * error rate
    * average confidence of misclassified samples (overconfident mistakes)
    * indices of the *most confident* misclassifications

    Parameters
    ----------
    y_true : np.ndarray
        Ground-truth labels, shape (n_samples,). Each distinct label is
        converted with ``int()`` to build its dictionary key.
    y_pred : np.ndarray
        Model predictions, shape (n_samples,).
    y_prob : np.ndarray
        Predicted probabilities, shape (n_samples,) for binary or
        (n_samples, n_classes) for multi-class. For 2-D input the row-wise
        maximum is used as the confidence; 1-D input is used as given.
    top_k : int, default 5
        Maximum number of most-confident mistakes to report per class.
        Values ``<= 0`` yield an empty list.

    Returns
    -------
    dict
        Nested dictionary keyed by class label. Each class entry holds
        ``support``, ``n_misclassified``, ``error_rate``,
        ``avg_misclassification_confidence`` and ``top_mistake_indices``
        (positions into the input arrays, most confident first). The extra
        key ``"__overall__"`` holds ``total_errors`` and
        ``overall_error_rate``.

    Examples
    --------
    >>> summary = misclassification_summary(y_true, y_pred, y_prob)
    >>> print(summary[1]["error_rate"])  # error rate for class 1
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    y_prob = np.asarray(y_prob)

    # Per-sample confidence: max probability across classes.
    # NOTE(review): a 1-D y_prob is used as-is. If it holds the positive-class
    # probability, confident negative predictions will look low-confidence
    # here -- confirm this is what callers intend.
    if y_prob.ndim == 1:
        max_conf = y_prob
    else:
        max_conf = y_prob.max(axis=1)

    incorrect_mask = y_true != y_pred
    summary: dict = {}

    for cls in np.unique(y_true):
        cls_mask = y_true == int(cls)
        cls_incorrect = cls_mask & incorrect_mask

        n_support = int(cls_mask.sum())
        n_misclassified = int(cls_incorrect.sum())
        error_rate = n_misclassified / n_support if n_support > 0 else 0.0

        miscls_confidences = max_conf[cls_incorrect]
        has_mistakes = len(miscls_confidences) > 0
        avg_misclassification_confidence = (
            float(miscls_confidences.mean()) if has_mistakes else 0.0
        )

        if has_mistakes and top_k > 0:
            k = min(top_k, len(miscls_confidences))
            # argsort ranks positions *within the filtered subset*; translate
            # them back to positions in the original arrays so callers can
            # look the samples up directly.
            sample_positions = np.flatnonzero(cls_incorrect)
            ranked = np.argsort(miscls_confidences)[-k:][::-1]
            top_mistake_indices = sample_positions[ranked].tolist()
        else:
            top_mistake_indices = []

        summary[int(cls)] = {
            "support": n_support,
            "n_misclassified": n_misclassified,
            "error_rate": round(error_rate, 4),
            "avg_misclassification_confidence": round(
                avg_misclassification_confidence, 4
            ),
            "top_mistake_indices": top_mistake_indices,
        }

    # Guard the mean so empty input reports 0.0 instead of NaN.
    overall_error_rate = (
        float(incorrect_mask.mean()) if incorrect_mask.size > 0 else 0.0
    )
    summary["__overall__"] = {
        "total_errors": int(incorrect_mask.sum()),
        "overall_error_rate": round(overall_error_rate, 4),
    }
    return summary
def confidence_gap(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_prob: np.ndarray,
    n_bins: int = 20,
) -> dict:
    """
    Measure the *confidence gap* — how much more confident is the model on
    correct predictions than on incorrect ones?

    Parameters
    ----------
    y_true : np.ndarray
        Ground-truth labels, shape (n_samples,).
    y_pred : np.ndarray
        Model predictions, shape (n_samples,).
    y_prob : np.ndarray
        Predicted probabilities. For 2-D input the row-wise maximum is used
        as the confidence; 1-D input is used as given.
    n_bins : int, default 20
        Number of equal-width histogram bins spanning [0, 1].

    Returns
    -------
    dict with keys:
        * ``correct_confidence_mean``   — mean confidence of correct preds
          (0.0 when there are none)
        * ``incorrect_confidence_mean`` — mean confidence of incorrect preds
          (0.0 when there are none)
        * ``gap``                  — mean(correct_conf) - mean(incorrect_conf),
          rounded to 4 decimals; 0.0 when either group is empty
        * ``histogram_bins``       — bin edges for the confidence histogram
        * ``correct_hist``         — histogram counts for correct predictions
        * ``incorrect_hist``       — histogram counts for incorrect predictions
        * ``n_correct``            — number of correct predictions
        * ``n_incorrect``          — number of incorrect predictions

    Examples
    --------
    >>> gap_data = confidence_gap(y_true, y_pred, y_prob)
    >>> print(f"Confidence gap: {gap_data['gap']:.3f}")
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    y_prob = np.asarray(y_prob)

    # NOTE(review): a 1-D y_prob is used as-is. If it holds the positive-class
    # probability, confident negative predictions will look low-confidence
    # here -- confirm this is what callers intend.
    if y_prob.ndim == 1:
        max_conf = y_prob
    else:
        max_conf = y_prob.max(axis=1)

    correct_mask = y_true == y_pred
    correct_conf = max_conf[correct_mask]
    incorrect_conf = max_conf[~correct_mask]

    bins = np.linspace(0.0, 1.0, n_bins + 1)
    correct_hist, _ = np.histogram(correct_conf, bins=bins)
    incorrect_hist, _ = np.histogram(incorrect_conf, bins=bins)

    has_correct = len(correct_conf) > 0
    has_incorrect = len(incorrect_conf) > 0
    correct_mean = float(correct_conf.mean()) if has_correct else 0.0
    incorrect_mean = float(incorrect_conf.mean()) if has_incorrect else 0.0

    # A gap is only meaningful when both groups exist; taking the mean of an
    # empty group would yield NaN (and a RuntimeWarning).
    gap = correct_mean - incorrect_mean if has_correct and has_incorrect else 0.0

    return {
        "correct_confidence_mean": correct_mean,
        "incorrect_confidence_mean": incorrect_mean,
        "gap": round(gap, 4),
        "histogram_bins": bins,
        "correct_hist": correct_hist,
        "incorrect_hist": incorrect_hist,
        "n_correct": int(correct_mask.sum()),
        "n_incorrect": int((~correct_mask).sum()),
    }