@@ -2,23 +2,38 @@
 import numpy as np
 import os
 import pandas as pd
-
+import json
+from utils import saps, saps_batch, create_path
+from copy import deepcopy
 """Main module to control and evaluate counterfactual harm (bound)"""
 
 class CounterfactualHarm:
-    def __init__(self, model, human, data) -> None:
+    def __init__(self, model, human, data, score_type='vanilla') -> None:
         self.model = model
         self.human = human
         self.data = data
         self.emp_risks = {}
+        self.score_type = score_type
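+        # score_type 'vanilla' thresholds the raw softmax scores; any other
+        # value switches to SAPS scores (see _empirical_risk/_empirical_benefit)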
         self.H = self.fn_per_set("h")
         self.G = self.fn_per_set("g")
         self.alphas = np.arange(0, 1 + config.alpha_step, config.alpha_step)
 
     def _empirical_risk(self, lamda):
         """Computes hat H(\lambda)"""
         # Set is valid or corner case of empty set
-        is_valid_set = (self.H_data_set['model_score_true_label'] >= (1 - lamda)) | (self.H_data_set['model_score_max'] == self.H_data_set['model_score_true_label'])
+        if self.score_type == 'vanilla':
+            is_valid_set = (self.H_data_set['model_score_true_label'] >= (1 - lamda)) | (self.H_data_set['model_score_max'] == self.H_data_set['model_score_true_label'])
+        else:
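+            # SAPS: the true label is covered when its SAPS score is at most
+            # lamda; the top-ranked label is always in the set (empty-set corner case)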
+            lamda_k = str(np.around(lamda, decimals=config.lamda_dec))
+            assert lamda_k in self.weights, self.weights.keys()
+            is_valid_set = (saps(
+                self.weights[lamda_k],
+                self.H_data_set['true_label_rank'],
+                self.H_data_set['model_score_max'],
+                config.saps_rng
+            ) <= lamda) | \
+                (self.H_data_set['true_label_rank'] == 1)
+
         if config.mode == 'control':
             emp_risk = self.H_data_set.where(is_valid_set, self.H_data_set['h_invalid_sets'], axis=0)['h_valid_sets'].mean()
             return emp_risk
@@ -32,7 +47,18 @@ def _empirical_risk(self, lamda):
     def _empirical_benefit(self, lamda):
         """Computes hat G(\lambda)"""
         # Set is valid or corner case of empty set
-        is_valid_set = (self.G_data_set['model_score_true_label'] >= (1 - lamda)) | (self.H_data_set['model_score_max'] == self.H_data_set['model_score_true_label'])
+        if self.score_type == 'vanilla':
+            is_valid_set = (self.G_data_set['model_score_true_label'] >= (1 - lamda)) | (self.G_data_set['model_score_max'] == self.G_data_set['model_score_true_label'])
+        else:
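+            # Same SAPS validity rule as in _empirical_risk, applied to the G data set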
+            lamda_k = str(np.around(lamda, decimals=config.lamda_dec))
+            is_valid_set = (saps(
+                self.weights[lamda_k],
+                self.G_data_set['true_label_rank'],
+                self.G_data_set['model_score_max'],
+                config.saps_rng
+            ) <= lamda) | \
+                (self.G_data_set['true_label_rank'] == 1)
+
 
         if config.mode == 'control':
             emp_ben = self.G_data_set.where(is_valid_set, self.G_data_set['g_invalid_sets'], axis=0)['g_valid_sets'].mean()
@@ -46,39 +72,38 @@ def _empirical_benefit(self, lamda):
 
     def fn_per_set(self, fn_name):
         """Reads/computes the h/g function for each prediction set"""
-        data_path = f"{config.ROOT_DIR}/data/{fn_name}/noise{config.noise_level}"
+        data_path = f"{config.ROOT_DIR}/data/{fn_name}_{self.score_type}/noise{config.noise_level}"
         file_path = f"{data_path}/{config.model_name}{'_pred_set' if config.HUMAN_DATASET == 'PS' else ''}.csv"
         if not os.path.exists(file_path):
             if not os.path.exists(data_path):
                 os.makedirs(data_path)
             fn_per_set = []
+
             # Compute the h/g value given each human prediction
             for image_name, participant_id, human_correct in self.human.itertuples(index=True):
                 true_label = self.data.loc[image_name]["category"]
                 model_score_true_label = self.model.loc[image_name][true_label]
-
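+                # Double argsort yields ascending 0-based ranks, so N_LABELS - rank
+                # assigns rank 1 to the label with the highest model score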
+                label_ranks = config.N_LABELS - self.model.drop(columns=['correct']).loc[image_name].argsort().argsort()
+                true_label_rank = label_ranks[true_label]
                 model_score_max = self.model.drop(columns=['correct']).loc[image_name].max()
 
-                if fn_name == "h":
+                if "h" in fn_name:
                     fn_value_valid = 0
                     fn_value_invalid = human_correct
                 else:
                     fn_value_valid = 1 - human_correct
                     fn_value_invalid = 0
 
-                fn_per_set.append((image_name, participant_id, model_score_true_label, model_score_max, fn_value_valid, fn_value_invalid))
+                fn_per_set.append((image_name, participant_id, model_score_true_label, model_score_max, true_label_rank, fn_value_valid, fn_value_invalid))
 
-            columns = ["image_name", "participant_id", "model_score_true_label", "model_score_max", f"{fn_name}_valid_sets", f"{fn_name}_invalid_sets"]
-
-            fn_df = pd.DataFrame(fn_per_set, columns=columns).set_index('image_name')
-
+            columns = ["image_name", "participant_id", "model_score_true_label", "model_score_max", "true_label_rank", f"{fn_name}_valid_sets", f"{fn_name}_invalid_sets"]
+            fn_df = pd.DataFrame(fn_per_set, columns=columns).set_index('image_name')
             fn_df.to_csv(file_path)
         else:
             fn_df = pd.read_csv(file_path, index_col='image_name')
 
         return fn_df
 
-
     def set_data_set(self, data_set):
         self.data_set = data_set
         self.data_set_size = len(data_set)
@@ -90,34 +115,86 @@ def control(self):
         """Min control level per lambda for h and g"""
         n = self.data_set_size
         thresholds = ((n + 1) * self.alphas - 1) / n
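         # Risk-controlling condition from conformal risk control: lambda controls
         # the risk at level alpha when hat H(lambda) <= ((n + 1) * alpha - 1) / n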
-        lamdas_dict = {np.around(lamda, decimals=3): {} for lamda in np.arange(0, 1 + config.lambda_step, config.lambda_step)}
+        lamdas_dict = {np.around(lamda, decimals=config.lamda_dec): {} for lamda in np.arange(config.lambda_min, config.lambda_max + config.lambda_step, config.lambda_step)}
 
-        for lamda in np.arange(0, 1 + config.lambda_step, config.lambda_step):
+        for lamda in np.arange(config.lambda_min, config.lambda_max + config.lambda_step, config.lambda_step):
             emp_risk_lamda = self._empirical_risk(lamda)
 
             # Min alpha such that lambda is harm controlling under CF (Counterfactual) monotonicity
             min_alpha_idx_CF = np.searchsorted(thresholds, emp_risk_lamda, side='left')
 
             # For each lambda keep the min level of control under CF monotonicity
-            lamdas_dict[np.around(lamda, decimals=3)]['CF'] = np.round(self.alphas[min_alpha_idx_CF], decimals=2)
+            lamdas_dict[np.around(lamda, decimals=config.lamda_dec)]['CF'] = np.round(self.alphas[min_alpha_idx_CF], decimals=2)
 
             # Empirical benefit (\hat G)
-            emp_benefit_lamda = self._empirical_benefit(1. - lamda)
+            emp_benefit_lamda = self._empirical_benefit(config.lambda_max - lamda)
 
             # Select the smallest alpha for which lambda is g controlling under cI (Interventional) monotonicity
             min_alpha_idx_cI = np.searchsorted(thresholds, emp_benefit_lamda, side='left')
-            lamdas_dict[np.around(1 - lamda, decimals=3)]['cI'] = np.round(self.alphas[min_alpha_idx_cI], decimals=2)
+            lamdas_dict[np.around(config.lambda_max - lamda, decimals=config.lamda_dec)]['cI'] = np.round(self.alphas[min_alpha_idx_cI], decimals=2)
 
         return lamdas_dict
 
     def compute(self):
         """Evaluate the counterfactual harm (bound)"""
-        lamdas_dict = {lamda: {} for lamda in np.arange(0, 1 + config.lambda_step, config.lambda_step)}
+        lamdas_dict = {np.around(lamda, decimals=config.lamda_dec): {} for lamda in np.arange(config.lambda_min, config.lambda_max + config.lambda_step, config.lambda_step)}
 
-        for lamda in np.arange(0, 1 + config.lambda_step, config.lambda_step):
+        for lamda in lamdas_dict.keys():
             harm_sum, harm_count = self._empirical_risk(lamda)
             g_harm, g_count = self._empirical_benefit(lamda)
             lamdas_dict[lamda]['hat_H'] = (harm_sum, harm_count)
             lamdas_dict[lamda]['hat_G'] = (g_harm, g_count)
 
-        return lamdas_dict
+        return lamdas_dict
+
+    def tune_saps(self, x_y, run):
+        """Find the optimal SAPS weight parameter value for each lamda value in a given run"""
+        self.weights = {}
+        weight_per_run_path = f"{config.results_path}/{config.model_name}/noise{config.noise_level}/saps/saps_weights_tune{config.tuning_split}_run{run}.json"
+        # Read the weights if already computed
+        if os.path.exists(weight_per_run_path):
+            with open(weight_per_run_path, 'rt') as f:
+                self.weights = json.load(f)
+        else:
+            self.weight_per_run = {}
+            lamdas = np.array([lamda for lamda in np.arange(config.lambda_min, config.lambda_max + config.lambda_step, config.lambda_step)])
+            min_avg_size_per_lambda = np.ones_like(lamdas) * config.N_LABELS
+            self.weight_per_lamda = np.ones_like(lamdas) * config.SAPS_WEIGHTS[0]
+            tune_set_probs = self.model.drop(columns=['correct']).loc[x_y.index]
+            tune_set_label_ranks = deepcopy(tune_set_probs)
+            for idx in x_y.index:
+                tune_set_label_ranks.loc[idx] = tune_set_probs.loc[idx].argsort()
+            max_scores = tune_set_probs.max(axis=1)
+
+            for weight in config.SAPS_WEIGHTS:
+                # Get the SAPS scores in the tune set
+                saps_scores_per_label = saps_batch(
+                    weight,
+                    tune_set_probs.loc[x_y.index.to_numpy()],
+                    tune_set_label_ranks,
+                    max_scores
+                )
+                # Compute the average prediction set size given a weight value
+                set_sizes_per_lambda = []
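+                # For each example, searchsorted (side='right', with sorter) counts
+                # how many labels have a SAPS score <= lamda: the set size per lambda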
+                for idx in saps_scores_per_label.index:
+                    set_sizes_per_lambda.append(np.searchsorted(
+                        saps_scores_per_label.loc[idx],
+                        lamdas,
+                        side='right',
+                        sorter=saps_scores_per_label.loc[idx].argsort()
+                    ))
+                set_sizes_per_lambda = np.stack(set_sizes_per_lambda).T
+                set_sizes_per_lambda[set_sizes_per_lambda == 0] = 1
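+                # Empty sets count as size 1: the top-ranked label is always included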
+                avg_set_sizes_per_lambda = set_sizes_per_lambda.mean(axis=1)
+                # Keep track of the weight value with the minimum average set size
+                self.weight_per_lamda[avg_set_sizes_per_lambda < min_avg_size_per_lambda] = weight
+                min_avg_size_per_lambda = np.minimum(avg_set_sizes_per_lambda, min_avg_size_per_lambda)
+            # Save the optimal weight for each lambda
+            for i, lamda in enumerate(lamdas):
+                lam_key = str(np.around(lamda, decimals=config.lamda_dec))
+                self.weights[lam_key] = self.weight_per_lamda[i]
+            create_path(f"{config.results_path}/{config.model_name}/noise{config.noise_level}/saps/")
+            with open(weight_per_run_path, 'w') as f:
+                json.dump(self.weights, f)