Skip to content

Commit 64c76fe

Browse files
committed
Merge remote-tracking branch 'origin/main'
2 parents 9c4b0cb + 70207f5 commit 64c76fe

File tree

3 files changed

+22
-38
lines changed

3 files changed

+22
-38
lines changed

bites/analyse/analyse_utils.py

+20-36
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,6 @@ def find_nearest_index(array, value):
129129
pred1_cf[i] = surv1_cf.axes[0][find_nearest_index(surv1_cf.iloc[:, i].values, death_probability)]
130130
ITE1 = pred1 - pred1_cf
131131

132-
correct_predicted_probability=None
133-
if best_treatment is not None:
134-
pred_best_choice0 = ((pred0 - pred0_cf) < 0) * 1
135-
pred_best_choice1 = ((pred1 - pred1_cf) > 0) * 1
136-
correct_predicted_probability = np.sum(np.append(best_treatment[treatment == 0] == pred_best_choice0,
137-
best_treatment[treatment == 1] == pred_best_choice1)) \
138-
/ (pred_best_choice0.size + pred_best_choice1.size)
139-
print('Fraction best choice: ' + str(correct_predicted_probability))
140-
141132
ITE = np.zeros(X.shape[0])
142133
k, j = 0, 0
143134
for i in range(X.shape[0]):
@@ -148,29 +139,30 @@ def find_nearest_index(array, value):
148139
ITE[i] = ITE1[j]
149140
j = j + 1
150141

142+
correct_predicted_probability=None
143+
if best_treatment is not None:
144+
correct_predicted_probability=np.sum(best_treatment==(ITE>0)*1)/best_treatment.shape[0]
145+
print('Fraction best choice: ' + str(correct_predicted_probability))
146+
151147
return ITE, correct_predicted_probability
152148

153149
def get_ITE_CFRNet(model, X, treatment, best_treatment=None):
154150

155151
pred,_ = model.predict_numpy(X, treatment)
156152
pred_cf,_ = model.predict_numpy(X, 1-treatment)
157153

158-
correct_predicted_probability=None
159-
if best_treatment is not None:
160-
pred_best_choice0=(pred_cf[treatment==0]-pred[treatment==0]>0) * 1
161-
pred_best_choice1=(pred[treatment==1]-pred_cf[treatment==1]>0) * 1
162-
correct_predicted_probability = np.sum(np.append(best_treatment[treatment == 0] == pred_best_choice0,
163-
best_treatment[treatment == 1] == pred_best_choice1)) \
164-
/ (pred_best_choice0.size + pred_best_choice1.size)
165-
print('Fraction best choice: ' + str(correct_predicted_probability))
166-
167154
ITE = np.zeros(X.shape[0])
168155
for i in range(X.shape[0]):
169156
if treatment[i] == 0:
170157
ITE[i] = pred_cf[i]-pred[i]
171158
else:
172159
ITE[i] = pred[i]-pred_cf[i]
173160

161+
correct_predicted_probability=None
162+
if best_treatment is not None:
163+
correct_predicted_probability=np.sum(best_treatment==(ITE>0)*1)/best_treatment.shape[0]
164+
print('Fraction best choice: ' + str(correct_predicted_probability))
165+
174166
return ITE, correct_predicted_probability
175167

176168

@@ -206,15 +198,6 @@ def find_nearest_index(array, value):
206198
pred1_cf[i] = surv1_cf.axes[0][find_nearest_index(surv1_cf.iloc[:, i].values, death_probability)]
207199
ITE1 = pred1 - pred1_cf
208200

209-
correct_predicted_probability=None
210-
if best_treatment is not None:
211-
pred_best_choice0 = ((pred0 - pred0_cf) < 0) * 1
212-
pred_best_choice1 = ((pred1 - pred1_cf) > 0) * 1
213-
correct_predicted_probability = np.sum(np.append(best_treatment[treatment == 0] == pred_best_choice0,
214-
best_treatment[treatment == 1] == pred_best_choice1)) \
215-
/ (pred_best_choice0.size + pred_best_choice1.size)
216-
print('Fraction best choice: ' + str(correct_predicted_probability))
217-
218201
ITE = np.zeros(X.shape[0])
219202
k, j = 0, 0
220203
for i in range(X.shape[0]):
@@ -225,6 +208,11 @@ def find_nearest_index(array, value):
225208
ITE[i] = ITE1[j]
226209
j = j + 1
227210

211+
correct_predicted_probability=None
212+
if best_treatment is not None:
213+
correct_predicted_probability=np.sum(best_treatment==(ITE>0)*1)/best_treatment.shape[0]
214+
print('Fraction best choice: ' + str(correct_predicted_probability))
215+
228216
return ITE, correct_predicted_probability
229217

230218
def get_ITE_DeepSurv(model, X, treatment, best_treatment=None, death_probability=0.5):
@@ -257,15 +245,6 @@ def find_nearest_index(array, value):
257245
pred1_cf[i] = surv1_cf.axes[0][find_nearest_index(surv1_cf.iloc[:, i].values, death_probability)]
258246
ITE1 = pred1 - pred1_cf
259247

260-
correct_predicted_probability=None
261-
if best_treatment is not None:
262-
pred_best_choice0 = ((pred0 - pred0_cf) < 0) * 1
263-
pred_best_choice1 = ((pred1 - pred1_cf) > 0) * 1
264-
correct_predicted_probability = np.sum(np.append(best_treatment[treatment == 0] == pred_best_choice0,
265-
best_treatment[treatment == 1] == pred_best_choice1)) \
266-
/ (pred_best_choice0.size + pred_best_choice1.size)
267-
print('Fraction best choice: ' + str(correct_predicted_probability))
268-
269248
ITE = np.zeros(X.shape[0])
270249
k, j = 0, 0
271250
for i in range(X.shape[0]):
@@ -276,6 +255,11 @@ def find_nearest_index(array, value):
276255
ITE[i] = ITE1[j]
277256
j = j + 1
278257

258+
correct_predicted_probability=None
259+
if best_treatment is not None:
260+
correct_predicted_probability=np.sum(best_treatment==(ITE>0)*1)/best_treatment.shape[0]
261+
print('Fraction best choice: ' + str(correct_predicted_probability))
262+
279263
return ITE, correct_predicted_probability
280264

281265

bites/model/Fit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torch.utils.data import *
1414

1515

16-
def fit (config, X_train, Y_train, event_train, treatment_train=None,**kwargs):
16+
def fit (config, X_train, Y_train, event_train=None, treatment_train=None,**kwargs):
1717
"""
1818
:param config: config file as given in the examples.
1919
:param X_train: np.array(num_samples, features)

examples/RGBSG_analyse.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070

7171
elif method == 'CFRNet':
7272
model, config = get_best_model("ray_results/" + method + "_RGBSG")
73-
pred_ite=get_ITE_CFRNet(model, X_test, treatment_test, best_treatment=None)
73+
pred_ite,_=get_ITE_CFRNet(model, X_test, treatment_test, best_treatment=None)
7474

7575

7676
# The loaded model can be used for further analysis!

0 commit comments

Comments
 (0)