Source code for trustlens.visualization

"""
trustlens.visualization.
========================
Visualization sub-package for TrustLens reports.

All plotting functions follow a consistent interface:
 * Accept pre-computed metric data (never raw model/data)
 * Return matplotlib Figure objects (for integration flexibility)
 * Accept optional ``save_path`` to write PNG files
 * Default to a clean, publication-quality style

The ``plot_module()`` dispatcher routes data to the appropriate plotter.
"""

from __future__ import annotations

import os
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np

from trustlens.visualization.bias_plots import plot_class_distribution
from trustlens.visualization.calibration_plots import plot_reliability_diagram
from trustlens.visualization.failure_plots import plot_confidence_gap
from trustlens.visualization.fairness import (
    _safe_name,
    plot_equalized_odds,
    plot_equalized_odds_multi,
    plot_fairness_gap,
    plot_fairness_gap_multi,
    plot_subgroup_performance,
    plot_subgroup_performance_multi,
)
from trustlens.visualization.representation_plots import (
    plot_embedding_2d,
    plot_embedding_separability,
)

__all__ = [
    "plot_reliability_diagram",
    "plot_confidence_gap",
    "plot_class_distribution",
    "plot_embedding_2d",
    "plot_embedding_separability",
    "plot_module",
    "plot_subgroup_performance",
    "plot_subgroup_performance_multi",
    "plot_equalized_odds",
    "plot_equalized_odds_multi",
    "plot_fairness_gap",
    "plot_fairness_gap_multi",
]

# ---------------------------------------------------------------------------
# Bias plot-type registry — deterministic ordering
# ---------------------------------------------------------------------------
_BIAS_PLOT_TYPES = (
    ("subgroup", plot_subgroup_performance_multi, "subgroup_performance"),
    ("equalized_odds", plot_equalized_odds_multi, "equalized_odds"),
    ("fairness_gap", plot_fairness_gap_multi, "equalized_odds"),
)


[docs] def plot_module( module_name: str, data: dict, save_dir: Optional[str] = None, *, embeddings: Optional[np.ndarray] = None, labels: Optional[np.ndarray] = None, ) -> None: """ Dispatch a module's result data to the appropriate visualization function. Parameters ---------- module_name : str Name of the analysis module (e.g., ``"calibration"``). data : dict Module result data from TrustReport.results[module_name]. save_dir : str, optional Directory to save the resulting PNG file(s). embeddings : np.ndarray, optional Embedding matrix (only used by ``"representation"`` module). labels : np.ndarray, optional Ground-truth labels (only used by ``"representation"`` module). """ if module_name == "representation": result = _plot_representation(data, embeddings=embeddings, labels=labels) elif module_name == "calibration": result = _plot_calibration(data) elif module_name == "failure": result = _plot_failure(data) elif module_name == "bias": result = _plot_bias(data) else: return if result is None: return # Short-circuit empty dict if isinstance(result, dict) and not result: return # Ensure output directory exists if save_dir: os.makedirs(save_dir, exist_ok=True) if isinstance(result, dict): for key, value in result.items(): if isinstance(value, dict): # Nested: dict[str, dict[str, Figure]] for subkey, subfig in value.items(): if subfig is not None: if save_dir: path = os.path.join( save_dir, f"{module_name}_{key}_{_safe_name(subkey)}.png", ) subfig.savefig(path, dpi=150, bbox_inches="tight") plt.close(subfig) else: # Flat: dict[str, Figure] if value is not None: if save_dir: path = os.path.join( save_dir, f"{module_name}_{key}.png", ) value.savefig(path, dpi=150, bbox_inches="tight") plt.close(value) else: # Single Figure (existing behaviour) if save_dir: save_path = os.path.join(save_dir, f"{module_name}_plot.png") result.savefig(save_path, dpi=150, bbox_inches="tight") plt.close(result)
def _plot_calibration(data: dict): if "reliability_curve" not in data: return None frac_pos, mean_pred, counts = data["reliability_curve"] return plot_reliability_diagram( frac_pos, mean_pred, ece=data.get("ece"), brier_score=data.get("brier_score"), ) def _plot_failure(data: dict): if "confidence_gap" not in data: return None return plot_confidence_gap(data["confidence_gap"]) def _plot_bias(data: dict): """Route bias data to the appropriate fairness visualizations. .. note:: Internal use only. Called by ``plot_module()``. Returns a single ``Figure`` for class-imbalance data, or a nested ``dict[str, dict[str, Figure]]`` keyed by plot type then feature when fairness metrics are present. File saving is handled exclusively by ``plot_module()``. """ if "class_imbalance" in data: return plot_class_distribution(data["class_imbalance"]) result = {} for key, multi_fn, data_key in _BIAS_PLOT_TYPES: if data_key in data: figures = multi_fn(data[data_key], save_dir=None, show=False) if figures: result[key] = figures return result if result else None def _plot_representation(data: dict, *, embeddings=None, labels=None): if "separability" not in data: return None fig_scorecard = plot_embedding_separability(data["separability"]) if embeddings is not None and labels is not None: sil = data["separability"].get("silhouette_score") fig_2d = plot_embedding_2d( embeddings=embeddings, labels=labels, silhouette_score=sil, show=False, ) return {"separability": fig_scorecard, "embedding_2d": fig_2d} return fig_scorecard