Commit 20b891e

Author: Eleni Straitouri (committed)
Message: add code
Parent: 48cc8dd

18 files changed, +324,302 -0 lines

config.py

Lines changed: 77 additions & 0 deletions
import os
import numpy as np
import argparse

"""Experiments configuration"""

ROOT_DIR = os.path.dirname(__file__)
DATA_PATH = f"{ROOT_DIR}/data"

N_RUNS = 50
DEBUG = False
N_LABELS = 16

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str,
                    choices=["alexnet", "vgg19", "densenet161", "googlenet", "resnet152"],
                    help="Choose the model used by the set-valued predictor.",
                    default="vgg19")
parser.add_argument("--noise_level", choices=[80, 95, 110, 125], type=int,
                    help="Choose the noise level applied to the images. "
                         "Datasets are available only for noise levels 80, 95, 110 and 125.",
                    default=110)
parser.add_argument("--alpha", type=float, default=0.1,
                    help="The desired risk control level.")
parser.add_argument("--mode", choices=["control", "test"], type=str,
                    help="Choose the operating mode: control (finding the harm-controlling lambdas) "
                         "or test (computing accuracy and harm over the test set).",
                    default="control")
parser.add_argument("--dataset", choices=["PS", ""], type=str,
                    help="Set to PS for experiments using the ImageNet16H-PS dataset. "
                         "Available only for model vgg19 and noise level 110.",
                    default="")
parser.add_argument("--mnl_ps", choices=[True, False], type=bool,
                    help="Set to True to run experiments with a mixture of MNLs "
                         "using the ImageNet16H-PS data.",
                    default=False)

args, unknown = parser.parse_known_args()

# Operating mode (test or control)
mode = args.mode

# The classifier used to produce the prediction sets
model_name = args.model_name

# The noise level applied to the images (higher noise, more difficult task)
noise_level = args.noise_level

# Desired harm control level
alpha = args.alpha

# Granularity of the harm level grid
alpha_step = 0.01

# Select dataset
HUMAN_DATASET = args.dataset

# Select data for the MNL mixture
MNL_PS = args.mnl_ps

# Initialize random generators
entropy = 0x3034c61a9ae04ff8cb62ab8ec2c4b501
numpy_rng = np.random.default_rng(entropy)

# Fraction of the dataset to use as calibration set
calibration_split = 0.1

# Granularity of the lambda grid
lambda_step = 0.001

# Set up paths for results and plots
# Path to store results
results_path = "results"

# Path to store plots
plot_path = "plots"
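Because config.py calls parser.parse_known_args() at import time, the rest of the code can simply import config and read the resolved settings as module attributes. A minimal sketch of a downstream consumer; the script name run_example.py and the invocation below are hypothetical, not part of this commit:

# run_example.py (hypothetical consumer of config.py, not part of this commit)
# Example invocation:
#   python run_example.py --model_name vgg19 --noise_level 110 --alpha 0.1 --mode control
import config

print(f"model:             {config.model_name}")
print(f"noise level:       {config.noise_level}")
print(f"harm level alpha:  {config.alpha}")
print(f"operating mode:    {config.mode}")
print(f"calibration split: {config.calibration_split}")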

counterfactual_harm.py

Lines changed: 123 additions & 0 deletions
import config
import numpy as np
import os
import pandas as pd

"""Main module to control and evaluate counterfactual harm (bound)"""


class CounterfactualHarm:
    def __init__(self, model, human, data) -> None:
        self.model = model
        self.human = human
        self.data = data
        self.emp_risks = {}
        self.H = self.fn_per_set("h")
        self.G = self.fn_per_set("g")
        self.alphas = np.arange(0, 1 + config.alpha_step, config.alpha_step)

    def _empirical_risk(self, lamda):
        r"""Computes hat H(\lambda)"""
        # Set is valid, or corner case of the empty set
        is_valid_set = (self.H_data_set['model_score_true_label'] >= (1 - lamda)) | (self.H_data_set['model_score_max'] == self.H_data_set['model_score_true_label'])
        if config.mode == 'control':
            emp_risk = self.H_data_set.where(is_valid_set, self.H_data_set['h_invalid_sets'], axis=0)['h_valid_sets'].mean()
            return emp_risk
        else:
            harm = self.H_data_set.where(is_valid_set, self.H_data_set['h_invalid_sets'], axis=0)['h_valid_sets']
            harm_sum = harm.sum()
            harm_count = harm.count()
            return harm_sum, harm_count

    def _empirical_benefit(self, lamda):
        r"""Computes hat G(\lambda)"""
        # Set is valid, or corner case of the empty set
        is_valid_set = (self.G_data_set['model_score_true_label'] >= (1 - lamda)) | (self.H_data_set['model_score_max'] == self.H_data_set['model_score_true_label'])

        if config.mode == 'control':
            emp_ben = self.G_data_set.where(is_valid_set, self.G_data_set['g_invalid_sets'], axis=0)['g_valid_sets'].mean()
            return emp_ben
        else:
            benefit = self.G_data_set.where(is_valid_set, self.G_data_set['g_invalid_sets'], axis=0)['g_valid_sets']
            g_sum = benefit.sum()
            g_count = benefit.count()
            return (g_sum, g_count)

    def fn_per_set(self, fn_name):
        """Reads/computes the h/g function for each prediction set"""
        data_path = f"{config.ROOT_DIR}/data/{fn_name}/noise{config.noise_level}"
        file_path = f"{data_path}/{config.model_name}{'_pred_set' if config.HUMAN_DATASET == 'PS' else ''}.csv"
        if not os.path.exists(file_path):
            if not os.path.exists(data_path):
                os.makedirs(data_path)
            fn_per_set = []
            # Compute the h/g value given each human prediction
            for image_name, participant_id, human_correct in self.human.itertuples(index=True):
                true_label = self.data.loc[image_name]["category"]
                model_score_true_label = self.model.loc[image_name][true_label]

                model_score_max = self.model.drop(columns=['correct']).loc[image_name].max()

                if fn_name == "h":
                    fn_value_valid = 0
                    fn_value_invalid = human_correct
                else:
                    fn_value_valid = 1 - human_correct
                    fn_value_invalid = 0

                fn_per_set.append((image_name, participant_id, model_score_true_label, model_score_max, fn_value_valid, fn_value_invalid))

            columns = ["image_name", "participant_id", "model_score_true_label", "model_score_max", f"{fn_name}_valid_sets", f"{fn_name}_invalid_sets"]

            fn_df = pd.DataFrame(fn_per_set, columns=columns).set_index('image_name')

            fn_df.to_csv(file_path)
        else:
            fn_df = pd.read_csv(file_path, index_col='image_name')

        return fn_df

    def set_data_set(self, data_set):
        self.data_set = data_set
        self.data_set_size = len(data_set)
        self.emp_risks = {}
        self.H_data_set = self.H.loc[self.data_set.index.values]
        self.G_data_set = self.G.loc[self.data_set.index.values]

    def control(self):
        """Min control level per lambda for h and g"""
        n = self.data_set_size
        thresholds = ((n + 1) * self.alphas - 1) / n
        lamdas_dict = {np.around(lamda, decimals=3): {} for lamda in np.arange(0, 1 + config.lambda_step, config.lambda_step)}

        for lamda in np.arange(0, 1 + config.lambda_step, config.lambda_step):
            emp_risk_lamda = self._empirical_risk(lamda)

            # Min alpha such that lambda is harm controlling under CF (counterfactual) monotonicity
            min_alpha_idx_CF = np.searchsorted(thresholds, emp_risk_lamda, side='left')

            # For each lambda keep the min level of control under CF monotonicity
            lamdas_dict[np.around(lamda, decimals=3)]['CF'] = np.round(self.alphas[min_alpha_idx_CF], decimals=2)

            # Empirical benefit (hat G)
            emp_benefit_lamda = self._empirical_benefit(1. - lamda)

            # Smallest alpha for which lambda is g-controlling under cI (interventional) monotonicity
            min_alpha_idx_cI = np.searchsorted(thresholds, emp_benefit_lamda, side='left')
            lamdas_dict[np.around(1 - lamda, decimals=3)]['cI'] = np.round(self.alphas[min_alpha_idx_cI], decimals=2)

        return lamdas_dict

    def compute(self):
        """Evaluate the counterfactual harm (bound)"""
        lamdas_dict = {lamda: {} for lamda in np.arange(0, 1 + config.lambda_step, config.lambda_step)}

        for lamda in np.arange(0, 1 + config.lambda_step, config.lambda_step):
            harm_sum, harm_count = self._empirical_risk(lamda)
            g_sum, g_count = self._empirical_benefit(lamda)
            lamdas_dict[lamda]['hat_H'] = (harm_sum, harm_count)
            lamdas_dict[lamda]['hat_G'] = (g_sum, g_count)

        return lamdas_dict
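For reference, the searchsorted call in control() picks, for each lambda on the grid, the smallest alpha whose threshold ((n+1)*alpha - 1)/n is at least the empirical value; restated as a formula (the h and g values are bounded by 1):

\[
\hat{H}(\lambda) \le \frac{(n+1)\alpha - 1}{n}
\quad\Longleftrightarrow\quad
\alpha \ge \frac{n\,\hat{H}(\lambda) + 1}{n + 1},
\]

and the same thresholds are reused with \hat{G}(1 - \lambda) to obtain the minimum control level under cI (interventional) monotonicity.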

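A minimal end-to-end sketch of how the class might be exercised, assuming pandas DataFrames indexed by image_name with the columns the code expects (per-label model scores plus a 'correct' column, participant id and correctness for the human predictions, and the true 'category'); the toy data, the 'toy_model' override, and all variable names below are illustrative assumptions, not part of this commit:

import pandas as pd

import config
from counterfactual_harm import CounterfactualHarm

# Toy stand-ins for the ImageNet16H tables (illustration only).
images = pd.Index(["img_0", "img_1", "img_2", "img_3"], name="image_name")
data = pd.DataFrame({"category": ["dog", "cat", "dog", "cat"]}, index=images)
model = pd.DataFrame({"dog": [0.9, 0.3, 0.6, 0.2],
                      "cat": [0.1, 0.7, 0.4, 0.8],
                      "correct": [1, 1, 1, 1]}, index=images)
human = pd.DataFrame({"participant_id": [1, 2, 1, 2],
                      "human_correct": [1, 0, 1, 1]}, index=images)

# Cache the toy h/g files under a scratch model name so they do not collide
# with the real per-model CSVs shipped with the repository (hypothetical override).
config.model_name = "toy_model"

ch = CounterfactualHarm(model, human, data)
ch.set_data_set(data)  # in the real pipeline this would be a calibration split
lamdas = ch.control()

# Lambdas whose minimum control level under CF monotonicity is at most alpha.
harm_controlling = [lam for lam, levels in lamdas.items()
                    if levels.get("CF", 1.0) <= config.alpha]
print(f"{len(harm_controlling)} harm-controlling lambdas at alpha={config.alpha}")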