Source code for trustlens.metrics.representation

"""
trustlens.metrics.representation.
=================================
Representation space analysis.

Probes the geometry of learned embedding spaces to understand:
* Whether classes are well-separated
* How similar two representation layers are (CKA)
* Whether cluster structure aligns with ground-truth labels

Metrics implemented
-------------------
* ``embedding_separability``  — silhouette score + within/between class distance
* ``centered_kernel_alignment`` — measures representational similarity between
  two sets of embeddings (e.g., two layers)

References
----------
* Kornblith, S., et al. (2019). Similarity of Neural Network Representations
  Revisited. ICML.
* Rousseeuw, P. (1987). Silhouettes: A graphical aid to the interpretation
  and validation of cluster analysis. Journal of Computational and Applied
  Mathematics.
"""

from __future__ import annotations

from typing import cast

import numpy as np
from sklearn.metrics import silhouette_score


def embedding_separability(
    embeddings: np.ndarray,
    y_true: np.ndarray,
    metric: str = "euclidean",
    sample_limit: int = 5000,
    max_pairs: int = 200,
) -> dict:
    """
    Measure how well class embeddings are separated in latent space.

    Uses the silhouette score as the primary separability measure, augmented
    with sampled within-class and between-class mean distances.

    Parameters
    ----------
    embeddings : np.ndarray
        Latent representations, shape (n_samples, embedding_dim).
    y_true : np.ndarray
        Ground-truth labels, shape (n_samples,).
    metric : str
        Distance metric passed to ``silhouette_score``. Default
        ``"euclidean"``. It only affects the silhouette score: the
        within/between class distances below are always Euclidean norms.
    sample_limit : int
        Maximum samples used for the computation (avoids O(n²) cost).
        A random subsample (fixed seed) is drawn when
        ``len(embeddings) > sample_limit``.
    max_pairs : int
        Maximum number of random pairs drawn per class for each of the
        within-class and between-class distance estimates. Default ``200``.

    Returns
    -------
    dict with keys:

    * ``silhouette_score``       — in [-1, 1]; 1.0 = perfect separation.
      ``nan`` when the score is undefined, i.e. when the number of distinct
      labels is not in ``[2, n_samples_used - 1]``.
    * ``within_class_distance``  — mean sampled distance between two
      *distinct* samples of the same class (``0.0`` if no such pair exists)
    * ``between_class_distance`` — mean sampled distance across classes
      (``0.0`` if there is only one class)
    * ``separability_ratio``     — between / within (> 1 preferred);
      ``inf`` when the within-class distance is 0
    * ``n_samples_used``         — number of samples after subsampling
    * ``embedding_dim``          — second dimension of ``embeddings``

    Raises
    ------
    ValueError
        If ``embeddings`` is not 2-D, or ``y_true`` is not a 1-D array with
        one label per embedding row.

    Examples
    --------
    >>> sep = embedding_separability(embeddings, y_true)
    >>> print(f"Silhouette: {sep['silhouette_score']:.3f}")
    """
    embeddings = np.asarray(embeddings, dtype=float)
    y_true = np.asarray(y_true)

    # Fail early with a clear message instead of an IndexError deep inside
    # the masking / norm code below.
    if embeddings.ndim != 2:
        raise ValueError(
            f"embeddings must be 2-D (n_samples, embedding_dim), got shape {embeddings.shape}."
        )
    n = embeddings.shape[0]
    if y_true.ndim != 1 or y_true.shape[0] != n:
        raise ValueError(
            f"y_true must have shape ({n},) to match embeddings, got shape {y_true.shape}."
        )

    # Subsample for large datasets (fixed seed keeps results reproducible).
    if n > sample_limit:
        rng = np.random.default_rng(42)
        idx = rng.choice(n, sample_limit, replace=False)
        embeddings_ss = embeddings[idx]
        y_true_ss = y_true[idx]
    else:
        embeddings_ss = embeddings
        y_true_ss = y_true

    classes = np.unique(y_true_ss)
    n_used = len(embeddings_ss)

    # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1 and
    # raises ValueError outside that range, so report NaN there instead.
    if 2 <= len(classes) < n_used:
        sil = float(silhouette_score(embeddings_ss, y_true_ss, metric=metric))
    else:
        sil = float("nan")

    # Within-class and between-class distances, estimated from a bounded
    # number of random pairs per class to keep the cost linear in n.
    within_dists: list = []
    between_dists: list = []
    rng = np.random.default_rng(0)

    for cls in classes:
        in_mask = y_true_ss == cls
        in_cls = embeddings_ss[in_mask]
        out_cls = embeddings_ss[~in_mask]
        n_in = len(in_cls)
        n_out = len(out_cls)

        if n_in >= 2:
            pairs = min(max_pairs, n_in * (n_in - 1) // 2)
            idx_a = rng.integers(0, n_in, pairs)
            # Draw the partner from the n_in - 1 *other* samples: pick from
            # [0, n_in - 1) and skip over idx_a. Two independent draws would
            # include self-pairs (distance 0) and bias the mean downward.
            idx_b = rng.integers(0, n_in - 1, pairs)
            idx_b = idx_b + (idx_b >= idx_a)
            diff = in_cls[idx_a] - in_cls[idx_b]
            within_dists.extend(np.linalg.norm(diff, axis=1).tolist())

        if n_in >= 1 and n_out >= 1:
            pairs = min(max_pairs, n_in * n_out)
            idx_a = rng.integers(0, n_in, pairs)
            idx_b = rng.integers(0, n_out, pairs)
            diff = in_cls[idx_a] - out_cls[idx_b]
            between_dists.extend(np.linalg.norm(diff, axis=1).tolist())

    within_mean = float(np.mean(within_dists)) if within_dists else 0.0
    between_mean = float(np.mean(between_dists)) if between_dists else 0.0
    sep_ratio = round(between_mean / within_mean, 4) if within_mean > 0 else float("inf")

    return {
        "silhouette_score": round(sil, 4),
        "within_class_distance": round(within_mean, 4),
        "between_class_distance": round(between_mean, 4),
        "separability_ratio": sep_ratio,
        "n_samples_used": n_used,
        "embedding_dim": embeddings.shape[1],
    }
def centered_kernel_alignment(
    X: np.ndarray,
    Y: np.ndarray,
) -> float:
    r"""
    Compute Centered Kernel Alignment (CKA) between two representation matrices.

    CKA is a representational similarity metric that is invariant to
    orthogonal transformations and isotropic scaling, making it suitable for
    comparing representations across architectures and layers. This
    implementation uses linear kernels (``X @ X.T`` and ``Y @ Y.T``).

    .. math::

        \text{CKA}(K, L) = \frac{\text{HSIC}(K, L)}{
            \sqrt{\text{HSIC}(K, K) \cdot \text{HSIC}(L, L)}}

    Parameters
    ----------
    X : np.ndarray
        First representation matrix, shape (n_samples, d1).
    Y : np.ndarray
        Second representation matrix, shape (n_samples, d2).

    Returns
    -------
    float
        CKA similarity score in [0, 1]. Higher → more similar representations.
        Returns ``0.0`` when the score is undefined: fewer than two samples,
        or a representation whose centered kernel is (numerically) zero,
        e.g. all rows identical.

    Raises
    ------
    ValueError
        If ``X`` or ``Y`` is not 2-D, or if they have different numbers of
        samples.

    Examples
    --------
    >>> cka = centered_kernel_alignment(layer1_embeddings, layer2_embeddings)
    >>> print(f"CKA similarity: {cka:.3f}")
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)

    if X.ndim != 2 or Y.ndim != 2:
        raise ValueError(
            f"X and Y must be 2-D (n_samples, dim), got shapes {X.shape} and {Y.shape}."
        )
    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"X and Y must have the same number of samples, got {X.shape[0]} and {Y.shape[0]}."
        )

    # With fewer than two samples the centered kernels are identically zero
    # and the HSIC normaliser (n - 1)^2 is zero: the score is undefined.
    if X.shape[0] < 2:
        return 0.0

    # Linear kernel matrices
    K_raw = X @ X.T
    L_raw = Y @ Y.T

    # Center the kernel matrices
    K = _center_kernel(K_raw)
    L = _center_kernel(L_raw)

    # Degeneracy is judged *relative* to the size of the uncentered kernel.
    # An absolute threshold would wrongly report 0.0 for perfectly valid but
    # small-magnitude inputs, contradicting CKA's scale invariance.
    rel_tol = 1e-12
    k_norm = float(np.linalg.norm(K))
    l_norm = float(np.linalg.norm(L))
    if k_norm <= rel_tol * float(np.linalg.norm(K_raw)):
        return 0.0
    if l_norm <= rel_tol * float(np.linalg.norm(L_raw)):
        return 0.0

    # CKA is unchanged by rescaling either kernel; unit Frobenius norm keeps
    # the HSIC terms O(1) regardless of the input magnitude.
    K = K / k_norm
    L = L / l_norm

    hsic_kl = _hsic(K, L)
    hsic_kk = _hsic(K, K)
    hsic_ll = _hsic(L, L)

    denom = np.sqrt(hsic_kk * hsic_ll)
    return float(np.clip(hsic_kl / denom, 0.0, 1.0))


def _center_kernel(K: np.ndarray) -> np.ndarray:
    """Double-center a kernel matrix.

    Computes ``H @ K @ H`` with ``H = I - 11ᵀ/n`` without materialising
    ``H``: subtracting the row means and column means and adding back the
    grand mean is algebraically identical and costs O(n²) instead of O(n³).
    """
    if K.shape[0] == 0:
        # Nothing to center; avoids NumPy's "mean of empty slice" warning.
        return cast(np.ndarray, K.copy())
    row_mean = K.mean(axis=1, keepdims=True)
    col_mean = K.mean(axis=0, keepdims=True)
    return cast(np.ndarray, K - row_mean - col_mean + K.mean())


def _hsic(K: np.ndarray, L: np.ndarray) -> float:
    """Biased HSIC estimator, ``trace(K @ L) / (n - 1)^2``.

    ``trace(K @ L)`` equals ``sum(K * L.T)``, which avoids forming the full
    O(n³) matrix product. Returns ``0.0`` for fewer than two samples, where
    the normaliser ``(n - 1)^2`` would be zero.
    """
    n = K.shape[0]
    if n < 2:
        return 0.0
    return float(np.sum(K * L.T)) / ((n - 1) ** 2)