import config
import numpy as np
import os
import pandas as pd

"""Main module to control and evaluate counterfactual harm (bound)"""
class CounterfactualHarm:
    """Controls and evaluates the counterfactual harm (bound) of a
    prediction-set based human-AI system.

    Holds per-(image, participant) harm values ``H`` (function ``h``) and
    benefit values ``G`` (function ``g``), and evaluates them over a grid
    of set thresholds ``lambda`` and control levels ``alpha``.
    """

    def __init__(self, model, human, data) -> None:
        # model: per-image model scores, indexed by image name (plus a 'correct' column).
        # human: per-(image, participant) correctness indicators.
        # data:  per-image metadata; provides the true 'category' label.
        self.model = model
        self.human = human
        self.data = data
        self.emp_risks = {}
        # Per-prediction-set h (harm) and g (benefit) tables, read or computed once.
        self.H = self.fn_per_set("h")
        self.G = self.fn_per_set("g")
        # Grid of candidate control levels alpha in [0, 1].
        self.alphas = np.arange(0, 1 + config.alpha_step, config.alpha_step)

    def _empirical_risk(self, lamda):
        r"""Computes \hat{H}(\lambda) on the current calibration split.

        A set is "valid" when the true label's score clears the threshold
        1 - lambda, or in the corner case where the true label already has
        the maximum score (empty-set corner case).

        Returns the mean in 'control' mode, else a (sum, count) pair.
        """
        # Set is valid, or corner case of empty set.
        is_valid_set = (
            (self.H_data_set['model_score_true_label'] >= (1 - lamda))
            | (self.H_data_set['model_score_max'] == self.H_data_set['model_score_true_label'])
        )
        # h value per row: 'h_valid_sets' where the set is valid, else 'h_invalid_sets'.
        harm = self.H_data_set.where(
            is_valid_set, self.H_data_set['h_invalid_sets'], axis=0
        )['h_valid_sets']
        if config.mode == 'control':
            return harm.mean()
        return harm.sum(), harm.count()

    def _empirical_benefit(self, lamda):
        r"""Computes \hat{G}(\lambda) on the current calibration split.

        Mirrors :meth:`_empirical_risk` with the g columns.
        Returns the mean in 'control' mode, else a (sum, count) pair.
        """
        # Set is valid, or corner case of empty set.
        # BUGFIX: the mask previously mixed in H_data_set columns; it only
        # worked because H and G share the same row index. Use G throughout.
        is_valid_set = (
            (self.G_data_set['model_score_true_label'] >= (1 - lamda))
            | (self.G_data_set['model_score_max'] == self.G_data_set['model_score_true_label'])
        )
        benefit = self.G_data_set.where(
            is_valid_set, self.G_data_set['g_invalid_sets'], axis=0
        )['g_valid_sets']
        if config.mode == 'control':
            return benefit.mean()
        return benefit.sum(), benefit.count()

    def fn_per_set(self, fn_name):
        """Reads/computes the h/g function for each prediction set.

        Results are cached as a CSV under ``config.ROOT_DIR``; when the file
        exists it is read back instead of recomputed.
        """
        data_path = f"{config.ROOT_DIR}/data/{fn_name}/noise{config.noise_level}"
        suffix = '_pred_set' if config.HUMAN_DATASET == 'PS' else ''
        file_path = f"{data_path}/{config.model_name}{suffix}.csv"

        if os.path.exists(file_path):
            return pd.read_csv(file_path, index_col='image_name')

        # exist_ok avoids the check-then-create race of the previous version.
        os.makedirs(data_path, exist_ok=True)
        fn_per_set = []
        # Compute the h/g value given each human prediction.
        for image_name, participant_id, human_correct in self.human.itertuples(index=True):
            true_label = self.data.loc[image_name]["category"]
            model_score_true_label = self.model.loc[image_name][true_label]
            # Max score over label columns only ('correct' is bookkeeping).
            model_score_max = self.model.drop(columns=['correct']).loc[image_name].max()

            if fn_name == "h":
                # Harm: valid sets contribute 0; invalid sets harm iff the human was correct.
                fn_value_valid = 0
                fn_value_invalid = human_correct
            else:
                # Benefit: valid sets benefit iff the human was wrong; invalid sets contribute 0.
                fn_value_valid = 1 - human_correct
                fn_value_invalid = 0

            fn_per_set.append((
                image_name, participant_id, model_score_true_label,
                model_score_max, fn_value_valid, fn_value_invalid,
            ))

        columns = [
            "image_name", "participant_id", "model_score_true_label",
            "model_score_max", f"{fn_name}_valid_sets", f"{fn_name}_invalid_sets",
        ]
        fn_df = pd.DataFrame(fn_per_set, columns=columns).set_index('image_name')
        fn_df.to_csv(file_path)
        return fn_df

    def set_data_set(self, data_set):
        """Selects the calibration split used by subsequent computations."""
        self.data_set = data_set
        self.data_set_size = len(data_set)
        self.emp_risks = {}
        self.H_data_set = self.H.loc[self.data_set.index.values]
        self.G_data_set = self.G.loc[self.data_set.index.values]

    def control(self):
        """Min control level (alpha) per lambda for h and g."""
        n = self.data_set_size
        # Finite-sample correction: lambda controls at level alpha when the
        # empirical risk is at most ((n+1)*alpha - 1)/n.
        thresholds = ((n + 1) * self.alphas - 1) / n
        lamdas_dict = {
            np.around(lamda, decimals=3): {}
            for lamda in np.arange(0, 1 + config.lambda_step, config.lambda_step)
        }

        for lamda in np.arange(0, 1 + config.lambda_step, config.lambda_step):
            emp_risk_lamda = self._empirical_risk(lamda)

            # Min alpha such that lambda is harm controlling under
            # CF (counterfactual) monotonicity.
            min_alpha_idx_CF = np.searchsorted(thresholds, emp_risk_lamda, side='left')
            lamdas_dict[np.around(lamda, decimals=3)]['CF'] = np.round(
                self.alphas[min_alpha_idx_CF], decimals=2
            )

            # Empirical benefit (\hat G), evaluated at the mirrored threshold.
            emp_benefit_lamda = self._empirical_benefit(1. - lamda)

            # Smallest alpha for which lambda is g-controlling under
            # cI (interventional) monotonicity.
            min_alpha_idx_cI = np.searchsorted(thresholds, emp_benefit_lamda, side='left')
            lamdas_dict[np.around(1 - lamda, decimals=3)]['cI'] = np.round(
                self.alphas[min_alpha_idx_cI], decimals=2
            )

        return lamdas_dict

    def compute(self):
        """Evaluate the counterfactual harm (bound) per lambda.

        Returns a dict keyed by the raw (unrounded) lambda grid values.
        NOTE(review): unlike :meth:`control`, keys are not rounded here —
        presumably consumers iterate the same np.arange grid; confirm.
        """
        lamdas_dict = {
            lamda: {} for lamda in np.arange(0, 1 + config.lambda_step, config.lambda_step)
        }

        for lamda in np.arange(0, 1 + config.lambda_step, config.lambda_step):
            harm_sum, harm_count = self._empirical_risk(lamda)
            g_sum, g_count = self._empirical_benefit(lamda)
            lamdas_dict[lamda]['hat_H'] = (harm_sum, harm_count)
            lamdas_dict[lamda]['hat_G'] = (g_sum, g_count)

        return lamdas_dict