Source code for bias_amplification.metrics.CoOccurMetrics

# Importing Libraries
import torch
from abc import ABC, abstractmethod
import sys
from pathlib import Path
from typing import Literal, Union, Callable, Tuple, Optional, Dict, Any, List

project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))


# ============================================================================
# BASE CO-OCCURRENCE METRIC
# ============================================================================
class BaseCoOccurMetric(ABC):
    """
    Abstract base class for co-occurrence-based bias amplification metrics.

    Provides the probability helpers (joint, marginal and conditional) that
    concrete metrics combine inside ``computeBiasAmp``.
    """

    def __init__(self):
        """No state is required by the shared helpers."""
        pass

    @staticmethod
    def _as_float(x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` unchanged if it is floating point, else cast it to float.

        Integer / boolean indicator tensors cannot be averaged by
        ``Tensor.mean``, so every helper normalises its inputs first.
        """
        return x if x.is_floating_point() else x.float()

    def computePairProbs(self, A: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
        """
        Computes joint probability for given A and T observations.

        Parameters
        ----------
        A : torch.Tensor
            Binary tensor of the shape (N x a). a is the number of possible
            attribute categories. (i.e. 2 for gender {male, female})
        T : torch.Tensor
            Binary tensor of the shape (N x t). t is the number of possible
            task categories.

        Returns
        -------
        probs : torch.Tensor of the shape (a x t).
            Represents the joint probability for each A-T pair.
        """
        A, T = self._as_float(A), self._as_float(T)
        num_obs = A.shape[0]
        # (a, N) x (N, t) -> (a, t) co-occurrence counts.
        counts = A.T @ T
        # Dividing by N (not by counts.sum()) keeps each entry a proper
        # P(A_a=1, T_t=1) even when a row carries several labels.
        return counts / num_obs

    def computeProbs(self, vals: torch.Tensor) -> torch.Tensor:
        """
        Computes observed probability for each category.

        Parameters
        ----------
        vals : torch.Tensor
            Binary tensor of the shape (N x v). v is the number of possible
            categories. (i.e. 2 for gender {male, female})

        Returns
        -------
        probs : torch.Tensor
            Float tensor of shape (v,) with the column-wise mean of ``vals``.
        """
        return self._as_float(vals).mean(dim=0)

    def computeAgivenT(self, A: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
        """
        Computes the share of each attribute among the observations of each task.

        Each column of the (a x t) co-occurrence count matrix is normalised to
        sum to 1, i.e. ``count(a, t) / sum_a' count(a', t)``. This equals
        P(A|T) when every row of A is one-hot.

        Parameters
        ----------
        A : torch.Tensor
            Binary tensor of the shape (N x a).
        T : torch.Tensor
            Binary tensor of the shape (N x t).

        Returns
        -------
        probs : torch.Tensor of the shape (a x t).
            Column-normalised co-occurrence for each A-T pair. Columns for
            tasks that never occur are all zero (the denominator is clamped).
        """
        A, T = self._as_float(A), self._as_float(T)
        counts = A.T @ T  # (a, t)
        return counts / counts.sum(dim=0).clamp(min=1e-10)

    def computeTgivenA(self, A: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
        """
        Computes conditional probability for all Task T given A observations.
        i.e P(T_t = 1 | A_a = 1) = count(a, t) / count(a)

        The denominator is the size of each attribute group rather than the
        row sum of the co-occurrence counts. Both agree when every row of T
        sums to 1 (one-hot T); the group size is also correct for multi-label
        T and for a single-column T, where the row-sum normalisation would
        return 1.0 for every non-empty group.

        Parameters
        ----------
        A : torch.Tensor
            Binary tensor of the shape (N x a).
        T : torch.Tensor
            Binary tensor of the shape (N x t).

        Returns
        -------
        probs : torch.Tensor of the shape (a x t).
            Represents the conditional probability P(T|A) for each A-T pair.
            Rows for attribute groups with no observations are all zero.
        """
        A, T = self._as_float(A), self._as_float(T)
        counts = A.T @ T  # (a, t)
        group_sizes = A.sum(dim=0).reshape(-1, 1)  # (a, 1)
        return counts / group_sizes.clamp(min=1e-10)

    @abstractmethod
    def computeBiasAmp(
        self, A: torch.Tensor, T: torch.Tensor, T_pred: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Abstract method to compute bias amplification.

        Subclasses must implement this method.

        Parameters
        ----------
        A : torch.Tensor
            Binary tensor of shape (N x a)
        T : torch.Tensor
            Binary tensor of shape (N x t)
        T_pred : torch.Tensor
            Binary tensor of shape (N x t)

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            A pair of tensors whose first element is a scalar summary of the
            bias amplification. The meaning of the second element is defined
            by each subclass.
        """
        pass
class BA_MALS(BaseCoOccurMetric):
    """
    Co-occurrence bias amplification restricted to positively biased pairs.

    An A-T pair counts as biased when the attribute's share among the
    observations of the task exceeds the uniform share 1 / |A|. The metric sums,
    over biased pairs only, the change in that share between ground-truth tasks
    and predicted tasks, normalised by the number of tasks.

    NOTE(review): the class name suggests this is the metric from
    "Men Also Like Shopping" (Zhao et al.) -- confirm against the project docs.
    """

    def __init__(self):
        super().__init__()

    def check_bias(self, A: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
        """
        Flags the A-T pairs in which the attribute is over-represented.

        A pair is biased when P(A|T) > 1 / |A|, i.e. the attribute holds more
        than a uniform share of the task's observations.

        Parameters
        ----------
        A : torch.Tensor
            Binary tensor of shape (N x a)
        T : torch.Tensor
            Binary tensor of shape (N x t)

        Returns
        -------
        is_biased : torch.Tensor
            Float mask of shape (a x t); 1.0 marks a biased pair.
        """
        P_A_given_T = self.computeAgivenT(A, T)
        uniform_share = 1 / A.shape[1]
        return (P_A_given_T > uniform_share).float()

    def computeBiasAmp(
        self, A: torch.Tensor, T: torch.Tensor, T_pred: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes bias amplification by comparing the conditional probabilities
        of A given T and A given T_pred.

        Parameters
        ----------
        A : torch.Tensor
            Binary tensor of shape (N x a)
        T : torch.Tensor
            Binary tensor of shape (N x t)
        T_pred : torch.Tensor
            Binary tensor of shape (N x t)

        Returns
        -------
        bias_amp_combined : torch.Tensor
            Scalar: sum of the per-pair contributions below.
        bias_amp : torch.Tensor
            Tensor of shape (a x t) holding, for each biased pair,
            (P(A|T_pred) - P(A|T)) / |T|; zero for unbiased pairs.
        """
        num_T = T.shape[1]
        # The mask is always derived from the ground truth, never from predictions.
        bias_mask = self.check_bias(A, T)
        # Use the same conditional quantity the mask is defined on; the joint
        # P(A, T) would mix in changes of the task marginal P(T_pred) vs P(T).
        delta = self.computeAgivenT(A, T_pred) - self.computeAgivenT(A, T)
        bias_amp = (bias_mask * delta) / num_T
        bias_amp_combined = torch.sum(bias_amp)
        return bias_amp_combined, bias_amp
class DBA(BaseCoOccurMetric):
    """
    Bias Amplification Metric from Directional Bias Amplification.

    This metric addresses shortcomings of Zhao's metric by accounting for both
    positive and negative correlations and for the direction of amplification.
    It compares the conditional probabilities of T given A and T_pred given A.
    """

    def __init__(self):
        super().__init__()

    def check_bias(self, A: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
        """
        Checks if each A-T pair exhibits statistical dependence (positive
        correlation). Uses independence test: P(A,T) > P(A)P(T)

        Parameters
        ----------
        A : torch.Tensor
            Binary tensor of shape (N x a)
        T : torch.Tensor
            Binary tensor of shape (N x t)

        Returns
        -------
        y_at : torch.Tensor
            Float mask of shape (a x t); 1.0 marks positively correlated pairs.
        """
        joint_probs = self.computePairProbs(A, T)
        A_probs = self.computeProbs(A).reshape(-1, 1)  # (a, 1)
        T_probs = self.computeProbs(T).reshape(1, -1)  # (1, t)
        # Broadcasting yields the (a, t) table expected under independence.
        independent_probs = A_probs * T_probs
        return (joint_probs > independent_probs) * 1.0

    def computeBiasAmp(
        self, A: torch.Tensor, T: torch.Tensor, T_pred: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes bias amplification in the A -> T direction by comparing the
        conditional probabilities of T given A and T_pred given A.

        Parameters
        ----------
        A : torch.Tensor
            Binary tensor of shape (N x a)
        T : torch.Tensor
            Binary tensor of shape (N x t)
        T_pred : torch.Tensor
            Binary tensor of shape (N x t)

        Returns
        -------
        bias_amp_combined : torch.Tensor
            Scalar: sum of the per-pair contributions below, i.e. the mean
            signed change over all |A| * |T| pairs.
        bias_amp : torch.Tensor
            Tensor of shape (a x t) holding each pair's signed change divided
            by |A| * |T|.
        """
        num_A = A.shape[1]
        num_T = T.shape[1]

        y_at = self.check_bias(A, T)
        delta_at = self.computeTgivenA(A, T_pred) - self.computeTgivenA(A, T)

        # Positively correlated pairs contribute +delta, the rest -delta, so
        # a positive value always means the existing correlation was amplified.
        bias_amp = (y_at * delta_at) + ((1 - y_at) * (-1 * delta_at))
        bias_amp = bias_amp / (num_A * num_T)
        bias_amp_combined = torch.sum(bias_amp)
        return bias_amp_combined, bias_amp

    def computeBiasAmpBidirectional(
        self,
        A: torch.Tensor,
        A_pred: torch.Tensor,
        T: torch.Tensor,
        T_pred: torch.Tensor,
    ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Computes bidirectional bias amplification for AtoT and TtoA directions.

        Parameters
        ----------
        A : torch.Tensor
            Binary tensor of shape (N x a)
        A_pred : torch.Tensor
            Binary tensor of shape (N x a)
        T : torch.Tensor
            Binary tensor of shape (N x t)
        T_pred : torch.Tensor
            Binary tensor of shape (N x t)

        Returns
        -------
        bias_amp : dict
            Dictionary with keys 'AtoT' and 'TtoA'. Each value is the
            (bias_amp_combined, bias_amp) tuple returned by ``computeBiasAmp``;
            for 'TtoA' the roles of A and T are swapped, so its per-pair
            tensor has shape (t x a).
        """
        bias_amp_AT = self.computeBiasAmp(A, T, T_pred)
        bias_amp_TA = self.computeBiasAmp(T, A, A_pred)
        return {"AtoT": bias_amp_AT, "TtoA": bias_amp_TA}
class MDBA(BaseCoOccurMetric):
    """
    Multi-Attribute Directional Bias Amplification Metric.

    Extends directional bias amplification from single attributes to
    combinations of attributes: for every combination m that occurs in the
    data it compares P(m | A) in the ground truth with P(m_pred | A) in the
    predictions.
    """

    def __init__(self, min_attr_size: int = 1, max_attr_size: Optional[int] = None):
        """
        Parameters
        ----------
        min_attr_size : int
            Smallest number of attributes in a combination.
        max_attr_size : int, optional
            Largest number of attributes in a combination. ``None`` (or 0)
            means "up to all attributes".
        """
        super().__init__()
        self.min_attr_size = min_attr_size
        self.max_attr_size = max_attr_size

    @staticmethod
    def _to_float(x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` unchanged if it is floating point, else cast it to float."""
        return x if x.is_floating_point() else x.float()

    @staticmethod
    def _combination_mask(T: torch.Tensor, indices: Tuple[int, ...]) -> torch.Tensor:
        """Return an (N x 1) mask that is 1 only where ALL columns in ``indices`` are 1."""
        # A product over the selected columns stays on T's device and dtype.
        return T[:, list(indices)].prod(dim=1, keepdim=True)

    @staticmethod
    def _prob_given_group(A: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
        """Return P(m = 1 | A_a = 1) for each group as an (a x 1) tensor."""
        counts = A.T @ m  # (a, 1) co-occurrence counts
        # Normalise by the group size. Normalising a single-column count
        # matrix by its own row sum would give 1.0 for every non-empty group.
        group_sizes = A.sum(dim=0).reshape(-1, 1).clamp(min=1e-10)
        return counts / group_sizes

    def check_bias(self, A: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
        """
        Checks if each A-T pair exhibits statistical dependence (positive
        correlation). Uses independence test: P(A,T) > P(A)P(T)

        Parameters
        ----------
        A : torch.Tensor
            Binary tensor of shape (N x a)
        T : torch.Tensor
            Binary tensor of shape (N x t) - represents ONE attribute
            combination when called from ``computeBiasAmp`` (t = 1)

        Returns
        -------
        y_at : torch.Tensor
            Float mask of shape (a x t); 1.0 marks positively correlated pairs.
        """
        joint_probs = self.computePairProbs(A, T)
        A_probs = self.computeProbs(A).reshape(-1, 1)  # (a, 1)
        T_probs = self.computeProbs(T).reshape(1, -1)  # (1, t)
        independent_probs = A_probs * T_probs
        return (joint_probs > independent_probs) * 1.0

    def _generateAttributeCombinations(
        self,
        T: torch.Tensor,
    ) -> List[Tuple[torch.Tensor, Tuple]]:
        """
        Generate all combinations of attributes for multi-attribute analysis.

        Parameters
        ----------
        T : torch.Tensor
            Attribute tensor, shape (N x t)

        Returns
        -------
        combinations : list[Tuple[torch.Tensor, Tuple]]
            List of (tensor, indices) tuples where:
            - tensor: binary mask of shape (N x 1) indicating presence of
              the combination
            - indices: tuple of attribute indices in the combination
            Combinations that never occur in ``T`` are omitted.
        """
        from itertools import combinations as iter_combinations

        T = self._to_float(T)
        num_T = T.shape[1]
        # A falsy max_attr_size (None or 0) means "no upper limit".
        max_size = self.max_attr_size if self.max_attr_size else num_T
        largest = min(max_size, num_T)

        found = []
        for size in range(self.min_attr_size, largest + 1):
            for combo in iter_combinations(range(num_T), size):
                combo_mask = self._combination_mask(T, combo)
                # Only keep combinations with at least one instance in the data.
                if combo_mask.sum() >= 1:
                    found.append((combo_mask, combo))
        return found

    def computeBiasAmp(
        self, A: torch.Tensor, T: torch.Tensor, T_pred: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes Multi-Dimensional Bias Amplification from A to T.

        Iterates over ALL attribute combinations m that occur in ``T`` and, for
        each group g, takes delta_gm = P(m_pred | g) - P(m | g). The sign of
        delta is flipped for pairs that are not positively correlated in the
        ground truth (see ``check_bias``).

        mean     = (1 / (|G| |M|)) * sum_g sum_m |signed delta_gm|
        variance = variance of the signed delta_gm over all (g, m) pairs

        Parameters
        ----------
        A : torch.Tensor
            Ground truth group membership, shape (N x a)
        T : torch.Tensor
            Ground truth tasks/attributes, shape (N x t)
        T_pred : torch.Tensor
            Predicted tasks/attributes, shape (N x t)

        Returns
        -------
        bias_amp_mean : torch.Tensor
            Scalar representing mean absolute bias amplification across all
            group-combination pairs. 0 when no combination occurs.
        bias_amp_variance : torch.Tensor
            Variance of the signed contributions (shows if amplification is
            uniform or concentrated). 0 when fewer than two pairs exist.
        """
        A = self._to_float(A)
        T_pred = self._to_float(T_pred)
        num_A = A.shape[1]

        combos = self._generateAttributeCombinations(T)
        num_M = len(combos)
        if num_M == 0:
            return torch.tensor(0.0), torch.tensor(0.0)

        signed_deltas = []
        for m_tensor, m_indices in combos:
            # Same combination, evaluated on the predictions.
            m_pred_tensor = self._combination_mask(T_pred, m_indices)

            # Correlation direction is always taken from the ground truth.
            y_am = self.check_bias(A, m_tensor)

            delta_am = self._prob_given_group(A, m_pred_tensor) - self._prob_given_group(
                A, m_tensor
            )
            weighted_delta = (y_am * delta_am) + ((1 - y_am) * (-delta_am))
            signed_deltas.append(weighted_delta.flatten())

        all_deltas = torch.cat(signed_deltas)
        bias_amp_mean = all_deltas.abs().sum() / (num_A * num_M)

        # The unbiased variance of a single value is undefined (NaN); report 0.
        if all_deltas.numel() > 1:
            bias_amp_variance = torch.var(all_deltas)
        else:
            bias_amp_variance = all_deltas.new_zeros(())
        return bias_amp_mean, bias_amp_variance

    def computeBiasAmpBidirectional(
        self,
        A: torch.Tensor,
        A_pred: torch.Tensor,
        T: torch.Tensor,
        T_pred: torch.Tensor,
    ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Computes bidirectional bias amplification.

        This captures bias amplification in both directions:
        - Multi_A->T (or Multi_G->M): how group membership (A) influences
          task predictions (T)
        - Multi_T->A (or Multi_M->G): how tasks (T) influence group
          membership predictions (A)

        Parameters
        ----------
        A : torch.Tensor
            Ground truth group membership
        A_pred : torch.Tensor
            Predicted group membership
        T : torch.Tensor
            Ground truth tasks/attributes
        T_pred : torch.Tensor
            Predicted tasks/attributes

        Returns
        -------
        bias_amp : dict
            Dictionary with keys 'AtoT' and 'TtoA', each containing the
            (mean, variance) tuple returned by ``computeBiasAmp``.
        """
        # A->T: condition on groups A, compare T with T_pred.
        bias_amp_AT = self.computeBiasAmp(A, T, T_pred)
        # T->A: condition on tasks T, compare A with A_pred.
        bias_amp_TA = self.computeBiasAmp(T, A, A_pred)
        return {"AtoT": bias_amp_AT, "TtoA": bias_amp_TA}

    def getAttributeCombinationStats(
        self,
        T: torch.Tensor,
    ) -> Dict:
        """
        Get statistics about attribute combinations in the dataset.
        Useful for understanding the dataset structure.

        Parameters
        ----------
        T : torch.Tensor
            Attribute tensor, shape (N x t)

        Returns
        -------
        stats : dict
            Dictionary containing:
            - 'total_combinations': Total number of attribute combinations
            - 'by_size': Number of combinations for each size
            - 'examples': Up to 5 example index tuples for each size
        """
        combos = self._generateAttributeCombinations(T)
        stats = {"total_combinations": len(combos), "by_size": {}, "examples": {}}

        for _, m_indices in combos:
            size = len(m_indices)
            stats["by_size"][size] = stats["by_size"].get(size, 0) + 1
            examples = stats["examples"].setdefault(size, [])
            if len(examples) < 5:  # Store up to 5 examples
                examples.append(m_indices)
        return stats
if __name__ == "__main__":
    # Demo: run DBA on the synthetic data for every (ground truth, model) pair.
    from utils.datacreator import dataCreator

    def _two_column(values) -> torch.Tensor:
        """Turn a binary vector into an (N x 2) tensor [value, 1 - value]."""
        column = torch.tensor(values, dtype=torch.float).reshape(-1, 1)
        return torch.hstack([column, 1 - column])

    # NOTE(review): assumes dataCreator returns five binary vectors in the
    # order (P, D, D2, M1, M2) -- confirm against utils.datacreator.
    P, D, D2, M1, M2 = (
        _two_column(values) for values in dataCreator(16384, 0.2, False, 0.05)
    )

    dpa_obj = DBA()
    cases = [(D, M1), (D, M2), (D2, M1), (D2, M2)]
    for case_no, (truth, predicted) in enumerate(cases, start=1):
        result = dpa_obj.computeBiasAmp(P, truth, predicted)
        print(f"DPA for case {case_no}: {result}")
        print("______________________________________")
        print("______________________________________")