Skip to content

Commit 9b850e0

Browse files
author
Eleni Straitouri
committed
Extend experiments to use set-valued predictors using the SAPS scores
1 parent b22af67 commit 9b850e0

8 files changed

+288
-53
lines changed

config.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
N_RUNS = 50
1111
DEBUG = False
1212
N_LABELS = 16
13+
SAPS_WEIGHTS = [0.02, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35]
1314

1415
parser = argparse.ArgumentParser()
1516
parser.add_argument("--model_name", type=str, \
@@ -35,6 +36,10 @@
3536
type=bool, help="Set to True to run experiments\
3637
with a mixture of MNLs using the ImageNet16H-PS data",
3738
default=False)
39+
parser.add_argument("--score_type", choices=["vanilla", "SAPS"], \
40+
default="vanilla", type=str,
41+
help="Choose the label score function type (non-conformity score)\
42+
to select the labels in the prediction set.")
3843

3944
args,unknown = parser.parse_known_args()
4045

@@ -62,12 +67,27 @@
6267
# Initialize random generators
6368
entropy = 0x3034c61a9ae04ff8cb62ab8ec2c4b501
6469
numpy_rng = np.random.default_rng(entropy)
70+
tune_rng = np.random.default_rng(entropy)
71+
saps_rng = np.random.default_rng(entropy)
6572

6673
# Fraction of the dataset to use as calibration set
6774
calibration_split = 0.1
6875

76+
# Fraction of the dataset to use as a tuning set for SAPS
77+
tuning_split = 0.1
78+
79+
# Name of score function (non-conformity score)
80+
score_type = args.score_type
81+
82+
# Range of lambda values according to score type
83+
lambda_min = 0.
84+
lambda_max = 1. if score_type == 'vanilla' else (1. + SAPS_WEIGHTS[-1]*(N_LABELS-1))
85+
6986
# The granularity for the lambda grid
70-
lambda_step = 0.001
87+
lambda_step = (lambda_max - lambda_min) / 1000
88+
89+
# Decimals to round up spurious lambda values due to np.arange
90+
lamda_dec = 3 if score_type == 'vanilla' else 5
7191

7292
# Set up paths for results and plots
7393
# Path to store results

counterfactual_harm.py

Lines changed: 98 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,38 @@
22
import numpy as np
33
import os
44
import pandas as pd
5-
5+
import json
6+
from utils import saps, saps_batch, create_path
7+
from copy import deepcopy
68
"""Main module to control and evaluate counterfactual harm (bound)"""
79

810
class CounterfactualHarm:
9-
def __init__(self, model, human, data) -> None:
11+
def __init__(self, model, human, data, score_type='vanilla') -> None:
1012
self.model = model
1113
self.human = human
1214
self.data = data
1315
self.emp_risks = {}
16+
self.score_type = score_type
1417
self.H = self.fn_per_set("h")
1518
self.G = self.fn_per_set("g")
1619
self.alphas = np.arange(0,1+config.alpha_step, config.alpha_step)
1720

1821
def _empirical_risk(self, lamda):
1922
"""Computes hat H(\lambda)"""
2023
# Set is valid or corner case of empty set
21-
is_valid_set = (self.H_data_set['model_score_true_label'] >= (1 - lamda)) | (self.H_data_set['model_score_max'] == self.H_data_set['model_score_true_label'])
24+
if self.score_type == 'vanilla':
25+
is_valid_set = (self.H_data_set['model_score_true_label'] >= (1 - lamda)) | (self.H_data_set['model_score_max'] == self.H_data_set['model_score_true_label'])
26+
else:
27+
lamda_k = str(np.around(lamda, decimals=config.lamda_dec))
28+
assert lamda_k in self.weights, self.weights.keys()
29+
is_valid_set = (saps(
30+
self.weights[lamda_k],
31+
self.H_data_set['true_label_rank'],
32+
self.H_data_set['model_score_max'],
33+
config.saps_rng
34+
) <= lamda) | \
35+
(self.H_data_set['true_label_rank'] == 1)
36+
2237
if config.mode == 'control':
2338
emp_risk = self.H_data_set.where(is_valid_set, self.H_data_set['h_invalid_sets'], axis=0)['h_valid_sets'].mean()
2439
return emp_risk
@@ -32,7 +47,18 @@ def _empirical_risk(self, lamda):
3247
def _empirical_benefit(self, lamda):
3348
"""Computes hat G(\lambda)"""
3449
# Set is valid or corner case of empty set
35-
is_valid_set = (self.G_data_set['model_score_true_label'] >= (1 - lamda)) | (self.H_data_set['model_score_max'] == self.H_data_set['model_score_true_label'])
50+
if self.score_type == 'vanilla':
51+
is_valid_set = (self.G_data_set['model_score_true_label'] >= (1 - lamda)) | (self.H_data_set['model_score_max'] == self.H_data_set['model_score_true_label'])
52+
else:
53+
lamda_k = str(np.around(lamda, decimals=config.lamda_dec))
54+
is_valid_set = (saps(
55+
self.weights[lamda_k],
56+
self.G_data_set['true_label_rank'],
57+
self.G_data_set['model_score_max'],
58+
config.saps_rng
59+
) <= lamda) | \
60+
(self.G_data_set['true_label_rank'] == 1)
61+
3662

3763
if config.mode == 'control':
3864
emp_ben = self.G_data_set.where(is_valid_set, self.G_data_set['g_invalid_sets'], axis=0)['g_valid_sets'].mean()
@@ -46,39 +72,38 @@ def _empirical_benefit(self, lamda):
4672

4773
def fn_per_set(self, fn_name):
4874
"""Reads/computes the h/g function for each prediction set"""
49-
data_path = f"{config.ROOT_DIR}/data/{fn_name}/noise{config.noise_level}"
75+
data_path = f"{config.ROOT_DIR}/data/{fn_name}_{self.score_type}/noise{config.noise_level}"
5076
file_path = f"{data_path}/{config.model_name}{'_pred_set' if config.HUMAN_DATASET=='PS' else ''}.csv"
5177
if not os.path.exists(file_path):
5278
if not os.path.exists(data_path):
5379
os.makedirs(data_path)
5480
fn_per_set = []
81+
5582
# Compute the h/g value given each human prediction
5683
for image_name, participant_id, human_correct in self.human.itertuples(index=True):
5784
true_label = self.data.loc[image_name]["category"]
5885
model_score_true_label = self.model.loc[image_name][true_label]
59-
86+
label_ranks = config.N_LABELS - self.model.drop(columns=['correct']).loc[image_name].argsort().argsort()
87+
true_label_rank = label_ranks[true_label]
6088
model_score_max = self.model.drop(columns=['correct']).loc[image_name].max()
6189

62-
if fn_name == "h":
90+
if "h" in fn_name:
6391
fn_value_valid = 0
6492
fn_value_invalid = human_correct
6593
else:
6694
fn_value_valid = 1 - human_correct
6795
fn_value_invalid = 0
6896

69-
fn_per_set.append((image_name, participant_id, model_score_true_label, model_score_max, fn_value_valid, fn_value_invalid))
97+
fn_per_set.append((image_name, participant_id, model_score_true_label, model_score_max, true_label_rank, fn_value_valid, fn_value_invalid))
7098

71-
columns = ["image_name", "participant_id", "model_score_true_label", "model_score_max", f"{fn_name}_valid_sets", f"{fn_name}_invalid_sets"]
72-
73-
fn_df = pd.DataFrame(fn_per_set, columns=columns).set_index('image_name')
74-
99+
columns = ["image_name", "participant_id", "model_score_true_label", "model_score_max", "true_label_rank",f"{fn_name}_valid_sets", f"{fn_name}_invalid_sets"]
100+
fn_df = pd.DataFrame(fn_per_set, columns=columns).set_index('image_name')
75101
fn_df.to_csv(file_path)
76102
else:
77103
fn_df = pd.read_csv(file_path, index_col='image_name')
78104

79105
return fn_df
80106

81-
82107
def set_data_set(self, data_set):
83108
self.data_set = data_set
84109
self.data_set_size = len(data_set)
@@ -90,34 +115,86 @@ def control(self):
90115
"""Min control level per lambda for h and g"""
91116
n = self.data_set_size
92117
thresholds = (((n+1)*self.alphas - 1)/n)
93-
lamdas_dict = { np.around(lamda, decimals=3):{} for lamda in np.arange(0,1+config.lambda_step,config.lambda_step)}
118+
lamdas_dict = {np.around(lamda, decimals=config.lamda_dec):{} for lamda in np.arange(config.lambda_min,config.lambda_max+config.lambda_step,config.lambda_step)}
94119

95-
for lamda in np.arange(0,1+config.lambda_step, config.lambda_step):
120+
for lamda in np.arange(config.lambda_min,config.lambda_max+config.lambda_step, config.lambda_step):
96121
emp_risk_lamda = self._empirical_risk(lamda)
97122

98123
# Min alpha such that lambda is harm controlling under CF (Counterfactual) monotonicity
99124
min_alpha_idx_CF = np.searchsorted(thresholds, emp_risk_lamda, side='left')
100125

101126
# For each lambda keep the min level of control under CF monotonicity
102-
lamdas_dict[np.around(lamda, decimals=3)]['CF'] = np.round(self.alphas[min_alpha_idx_CF], decimals=2)
127+
lamdas_dict[np.around(lamda, decimals=config.lamda_dec)]['CF'] = np.round(self.alphas[min_alpha_idx_CF], decimals=2)
103128

104129
# Empirical benefit (\hat G)
105-
emp_benefit_lamda = self._empirical_benefit(1.-lamda)
130+
emp_benefit_lamda = self._empirical_benefit(config.lambda_max-lamda)
106131

107132
# Select smallest alpha that for which lambda is g controlling under cI (Interventional) monotonicity
108133
min_alpha_idx_cI = np.searchsorted(thresholds, emp_benefit_lamda, side='left')
109-
lamdas_dict[np.around(1 - lamda, decimals=3)]['cI'] = np.round(self.alphas[min_alpha_idx_cI], decimals=2)
134+
lamdas_dict[np.around(config.lambda_max - lamda, decimals=config.lamda_dec)]['cI'] = np.round(self.alphas[min_alpha_idx_cI], decimals=2)
110135

111136
return lamdas_dict
112137

113138
def compute(self):
114139
"""Evaluate the counterfactual harm (bound)"""
115-
lamdas_dict = { lamda:{} for lamda in np.arange(0,1+config.lambda_step,config.lambda_step)}
140+
lamdas_dict = {np.around(lamda, decimals=config.lamda_dec):{} for lamda in np.arange(config.lambda_min,config.lambda_max+config.lambda_step,config.lambda_step)}
116141

117-
for lamda in np.arange(0,1+config.lambda_step, config.lambda_step):
142+
for lamda in lamdas_dict.keys():
118143
harm_sum, harm_count = self._empirical_risk(lamda)
119144
g_harm, g_count = self._empirical_benefit(lamda)
120145
lamdas_dict[lamda]['hat_H'] = (harm_sum, harm_count)
121146
lamdas_dict[lamda]['hat_G'] = (g_harm, g_count)
122147

123-
return lamdas_dict
148+
return lamdas_dict
149+
150+
def tune_saps(self, x_y, run):
151+
"""Find the optimal saps weight parameter value for each lamda value in a given run"""
152+
self.weights = {}
153+
weight_per_run_path = f"{config.results_path}/{config.model_name}/noise{config.noise_level}/saps/saps_weights_tune{config.tuning_split}_run{run}.json"
154+
# Read the weights if already computed
155+
if os.path.exists(weight_per_run_path):
156+
with open(weight_per_run_path, 'rt') as f:
157+
self.weights = json.load(f)
158+
else:
159+
self.weight_per_run = {}
160+
lamdas = np.array([lamda for lamda in np.arange(config.lambda_min,config.lambda_max+config.lambda_step,config.lambda_step)])
161+
min_avg_size_per_lambda = np.ones_like(lamdas)*config.N_LABELS
162+
self.weight_per_lamda = np.ones_like(lamdas)*config.SAPS_WEIGHTS[0]
163+
tune_set_probs = self.model.drop(columns=['correct']).loc[x_y.index]
164+
tune_set_label_ranks = deepcopy(tune_set_probs)
165+
for idx in x_y.index:
166+
tune_set_label_ranks.loc[idx] = tune_set_probs.loc[idx].argsort()
167+
max_scores = tune_set_probs.max(axis=1)
168+
169+
for weight in config.SAPS_WEIGHTS:
170+
# Get the saps scores in the tune set
171+
saps_scores_per_label = saps_batch(
172+
weight,
173+
tune_set_probs.loc[x_y.index.to_numpy()],
174+
tune_set_label_ranks,
175+
max_scores
176+
)
177+
# Compute the average prediction set size given a weight value
178+
set_sizes_per_lambda = []
179+
for idx in saps_scores_per_label.index:
180+
set_sizes_per_lambda.append(np.searchsorted(
181+
saps_scores_per_label.loc[idx],
182+
lamdas,
183+
side='right',
184+
sorter=saps_scores_per_label.loc[idx].argsort()
185+
))
186+
set_sizes_per_lambda = np.stack(set_sizes_per_lambda, ).T
187+
set_sizes_per_lambda[set_sizes_per_lambda == 0] = 1
188+
avg_set_sizes_per_lambda = set_sizes_per_lambda.mean(axis=1)
189+
# Keep track of the weight value with the minimum average set size
190+
self.weight_per_lamda[
191+
avg_set_sizes_per_lambda <
192+
min_avg_size_per_lambda] = weight
193+
min_avg_size_per_lambda = np.minimum(avg_set_sizes_per_lambda, min_avg_size_per_lambda)
194+
# Save the optimal weight for each lambda
195+
for i,lamda in enumerate(lamdas):
196+
lam_key = str(np.around(lamda, decimals=config.lamda_dec))
197+
self.weights[lam_key] = self.weight_per_lamda[i]
198+
create_path(f"{config.results_path}/{config.model_name}/noise{config.noise_level}/saps/")
199+
with open(weight_per_run_path, 'w') as f:
200+
json.dump(self.weights,f)

0 commit comments

Comments
 (0)