Skip to content

Commit 55410f0

Browse files
committed
first commit
1 parent 05dcf8d commit 55410f0

38 files changed

+2670
-17
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/bites/data/RGBSG/rgbsg.h5

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ and analyse the findings with
4545
The complete workflow for BITES, DeepSurv and CFRNet is fully controllable by setting the ``config`` parameters
4646
````python
4747
config = {
48-
"Method": 'BITES', #'ITES', 'DeepSurv', 'CFRNet'
48+
"Method": 'bites', #'ITES', 'DeepSurv', 'CFRNet'
4949
"trial_name": 'RGBSG',
5050
"result_dir": './ray_results',
5151
"val_set_fraction": 0.2,

bites/__init__.py

Whitespace-only changes.

bites/analyse/__init__.py

Whitespace-only changes.

bites/analyse/analyse_utils.py

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
import os
2+
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
import pandas as pd
6+
import seaborn as sns
7+
import torch
8+
from BITES.model.BITES_base import BITES
9+
from BITES.model.CFRNet_base import CFRNet
10+
from BITES.model.DeepSurv_base import DeepSurv
11+
from BITES.utils.eval_surv import EvalSurv
12+
from lifelines import KaplanMeierFitter
13+
from lifelines.statistics import logrank_test
14+
from ray.tune import Analysis
15+
16+
17+
def get_best_model(path_to_experiment="./ray_results/test_hydra", assign_treatment=None):
    """Rebuild the best network of a Ray Tune experiment and load its weights.

    The trial with the lowest ``val_loss`` is selected, a network matching the
    ``"Method"`` entry of its config is constructed, and the model state stored
    in that trial's best checkpoint is loaded (mapped to the CPU).

    Args:
        path_to_experiment: Directory containing the Ray Tune results.
        assign_treatment: Stored as ``treatment`` attribute on the network when
            the method is DeepSurv/DeepSurvT; unused for the other methods.

    Returns:
        ``(best_net, best_config)``. If the configured method is not supported
        a message is printed and ``None`` is returned instead.
    """
    analysis = Analysis(path_to_experiment, default_metric="val_loss", default_mode="min")
    best_config = analysis.get_best_config()
    best_checkpoint_dir = analysis.get_best_checkpoint(analysis.get_best_logdir())

    # Match the method name case-insensitively so that configs spelling it in
    # lower case (e.g. 'bites') resolve to the same network as 'BITES'.
    method = str(best_config["Method"]).lower()

    if method in ('bites', 'ites'):
        best_net = BITES(best_config["num_covariates"], best_config["shared_layer"],
                         best_config["individual_layer"],
                         out_features=1,
                         dropout=best_config["dropout"])
    elif method in ('deepsurv', 'deepsurvt'):
        best_net = DeepSurv(best_config["num_covariates"], best_config["shared_layer"],
                            out_features=1,
                            dropout=best_config["dropout"])
        best_net.treatment = assign_treatment
    elif method == 'cfrnet':
        best_net = CFRNet(best_config["num_covariates"], best_config["shared_layer"],
                          out_features=1,
                          dropout=best_config["dropout"])
    else:
        print('Method not implemented yet!')
        return None

    # The checkpoint stores a (model_state, optimizer_state) pair; only the
    # model weights are needed to run the network for analysis.
    model_state, _optimizer_state = torch.load(
        os.path.join(best_checkpoint_dir, "checkpoint"), map_location=torch.device('cpu'))
    best_net.load_state_dict(model_state)

    return best_net, best_config
46+
47+
48+
def get_C_Index_BITES(model, X, time, event, treatment):
    """Print and return the concordance of a BITES model, pooled and per arm.

    Args:
        model: Object exposing ``baseline_hazards_`` and
            ``predict_surv_df(X, treatment)``, which returns the survival
            frames of the ``treatment == 0`` and ``treatment == 1`` samples.
        X: Covariates handed to the model.
        time: Observed times, indexable with a boolean mask.
        event: Event indicators, indexable with a boolean mask.
        treatment: Treatment assignment; compared against 0 and 1.

    Returns:
        ``(C_index, C_index0, C_index1)`` - pooled, control-arm and treated-arm
        results of ``EvalSurv(...).concordance_td()``. Returns ``None`` after
        printing a hint when ``model.baseline_hazards_`` is falsy.
    """
    if not model.baseline_hazards_:
        print('Compute Baseline Hazards before running get_C_index')
        return

    control = treatment == 0
    treated = treatment == 1
    surv0, surv1 = model.predict_surv_df(X, treatment)

    # Pool both arms column-wise and interpolate along the index so the curves
    # share one time grid. The last 100 rows are then dropped (remaining NaNs
    # set to 1) and afterwards the first 100 rows (remaining NaNs set to 0).
    pooled = pd.concat([surv0, surv1], axis=1).interpolate('index')
    pooled = pooled.iloc[:-100, :].fillna(1)
    pooled = pooled.iloc[100:, :].fillna(0)

    C_index0 = EvalSurv(surv0, time[control], event[control], censor_surv='km').concordance_td()
    C_index1 = EvalSurv(surv1, time[treated], event[treated], censor_surv='km').concordance_td()

    # Pooled labels follow the column order of the pooled frame: control, then treated.
    pooled_time = np.append(time[control], time[treated])
    pooled_event = np.append(event[control], event[treated])
    C_index = EvalSurv(pooled, pooled_time, pooled_event, censor_surv='km').concordance_td()

    print(f"Time dependent C-Index: {str(C_index)[:5]}")
    print(f"Case0 C-Index: {str(C_index0)[:5]}")
    print(f"Case1 C-Index: {str(C_index1)[:5]}")

    return C_index, C_index0, C_index1
69+
70+
71+
def get_C_Index_DeepSurvT(model0, model1, X, time, event, treatment):
    """Print and return the concordance of a two-model DeepSurv setup.

    ``model0`` predicts the ``treatment == 0`` samples and ``model1`` the
    ``treatment == 1`` samples; each is evaluated on its own arm and both
    together on the pooled data.

    Args:
        model0: Object with ``predict_surv_df(X)`` used for the control arm.
        model1: Object with ``predict_surv_df(X)`` used for the treated arm.
        X: Covariates, indexable with a boolean mask.
        time: Observed times, indexable with a boolean mask.
        event: Event indicators, indexable with a boolean mask.
        treatment: Treatment assignment; compared against 0 and 1.

    Returns:
        ``(C_index, C_index0, C_index1)`` - pooled, control-arm and treated-arm
        results of ``EvalSurv(...).concordance_td()``.
    """
    control = treatment == 0
    treated = treatment == 1

    time0, event0 = time[control], event[control]
    time1, event1 = time[treated], event[treated]
    surv0 = model0.predict_surv_df(X[control])
    surv1 = model1.predict_surv_df(X[treated])

    # Pool both arms column-wise and interpolate along the index so the curves
    # share one time grid. The last 100 rows are then dropped (remaining NaNs
    # set to 1) and afterwards the first 100 rows (remaining NaNs set to 0).
    pooled = pd.concat([surv0, surv1], axis=1).interpolate('index')
    pooled = pooled.iloc[:-100, :].fillna(1)
    pooled = pooled.iloc[100:, :].fillna(0)

    pooled_time = np.append(time0, time1)
    pooled_event = np.append(event0, event1)
    C_index = EvalSurv(pooled, pooled_time, pooled_event, censor_surv='km').concordance_td()
    C_index0 = EvalSurv(surv0, time0, event0, censor_surv='km').concordance_td()
    C_index1 = EvalSurv(surv1, time1, event1, censor_surv='km').concordance_td()

    print(f"Time dependent C-Index: {str(C_index)[:5]}")
    print(f"Case0 C-Index: {str(C_index0)[:5]}")
    print(f"Case1 C-Index: {str(C_index1)[:5]}")

    return C_index, C_index0, C_index1
95+
96+
97+
def get_ITE_BITES(model, X, treatment, best_treatment=None, death_probability=0.5):
    """Estimate individual treatment effects (ITE) with a BITES model.

    For every sample the time at which the predicted survival curve is closest
    to ``death_probability`` is read off the factual and the counterfactual
    curve. The ITE is that time under treatment minus that time under control,
    so a positive value favours treatment.

    Args:
        model: Object exposing ``baseline_hazards_``,
            ``predict_surv_df(X, treatment)`` and
            ``predict_surv_counterfactual_df(X, treatment)``. Both predictors
            return a ``(control, treated)`` pair of frames with one column per
            sample and the time grid as index.
        X: Covariates; only ``X.shape[0]`` is used here besides being handed
            to the model.
        treatment: Per-sample treatment assignment (0 = control).
        best_treatment: Optional per-sample ground-truth best treatment (0/1).
            When given, the fraction of samples whose predicted best treatment
            matches is computed and printed.
        death_probability: Survival probability at which the curves are read.

    Returns:
        ``(ITE, correct_predicted_probability)`` where ``ITE`` has one entry
        per row of ``X`` and the second item is ``None`` when
        ``best_treatment`` is not given. Returns ``None`` after printing a
        hint when ``model.baseline_hazards_`` is falsy.
    """
    if not model.baseline_hazards_:
        print('Compute Baseline Hazards before running get_ITE()')
        return

    def time_at_probability(surv):
        """Return, per column, the index value whose survival is nearest the target."""
        if surv.shape[1] == 0:
            return np.zeros(0)
        # argmin along axis 0 picks the first closest row of every column.
        nearest_rows = np.abs(surv.values - death_probability).argmin(axis=0)
        return np.asarray(surv.index)[nearest_rows].astype(float)

    surv0, surv1 = model.predict_surv_df(X, treatment)
    surv0_cf, surv1_cf = model.predict_surv_counterfactual_df(X, treatment)

    # Control samples: counterfactual (treated) minus factual (control).
    ITE0 = time_at_probability(surv0_cf) - time_at_probability(surv0)
    # Treated samples: factual (treated) minus counterfactual (control).
    ITE1 = time_at_probability(surv1) - time_at_probability(surv1_cf)

    treatment_arr = np.asarray(treatment).reshape(-1)
    is_control = treatment_arr == 0

    correct_predicted_probability = None
    # `is not None` instead of truthiness: the truth value of a multi-element
    # array is ambiguous and raises ValueError.
    if best_treatment is not None:
        best = np.asarray(best_treatment)
        if best.ndim > 1:
            # NOTE(review): the earlier implementation read column 3 of a 2-D
            # array, but only for the treated arm. The same column is used for
            # both arms here - confirm the layout with the caller.
            best = best[:, 3]
        pred_best_choice0 = (ITE0 > 0) * 1
        pred_best_choice1 = (ITE1 > 0) * 1
        hits = np.append(best[is_control] == pred_best_choice0,
                         best[treatment_arr == 1] == pred_best_choice1)
        correct_predicted_probability = np.sum(hits) / (pred_best_choice0.size + pred_best_choice1.size)
        print('Fraction best choice: ' + str(correct_predicted_probability))

    # Scatter the per-arm effects back into the original sample order; every
    # sample that is not control is taken from the treated arm.
    ITE = np.zeros(X.shape[0])
    ITE[is_control] = ITE0
    ITE[~is_control] = ITE1

    return ITE, correct_predicted_probability
144+
145+
def get_ITE_CFRNet(model, X, treatment, best_treatment=None):
    """Estimate individual treatment effects (ITE) with a CFRNet model.

    The model is queried once with the factual treatment and once with the
    flipped treatment (``1 - treatment``). The ITE is the prediction under
    treatment minus the prediction under control, so a positive value favours
    treatment.

    Args:
        model: Object exposing ``predict_numpy(X, treatment)`` whose first
            return value holds one prediction per sample.
        X: Covariates handed to the model.
        treatment: Per-sample treatment assignment (0 = control).
        best_treatment: Optional per-sample ground-truth best treatment (0/1).
            When given, the fraction of samples whose predicted best treatment
            matches is computed and printed.

    Returns:
        ``(ITE, correct_predicted_probability)``; the second item is ``None``
        when ``best_treatment`` is not given.
    """
    pred, _ = model.predict_numpy(X, treatment)
    pred_cf, _ = model.predict_numpy(X, 1 - treatment)

    # One value per sample; flattening keeps the masks below aligned even if
    # the model returns a column vector.
    pred = np.asarray(pred, dtype=float).reshape(-1)
    pred_cf = np.asarray(pred_cf, dtype=float).reshape(-1)

    treatment_arr = np.asarray(treatment).reshape(-1)
    is_control = treatment_arr == 0

    # Control: counterfactual (treated) - factual. Every other sample:
    # factual (treated) - counterfactual (control).
    ITE = np.where(is_control, pred_cf - pred, pred - pred_cf)

    correct_predicted_probability = None
    # `is not None` instead of truthiness: the truth value of a multi-element
    # array is ambiguous and raises ValueError.
    if best_treatment is not None:
        best = np.asarray(best_treatment)
        if best.ndim > 1:
            # NOTE(review): the earlier implementation read column 3 of a 2-D
            # array, but only for the treated arm. The same column is used for
            # both arms here - confirm the layout with the caller.
            best = best[:, 3]
        is_treated = treatment_arr == 1
        pred_best_choice0 = (ITE[is_control] > 0) * 1
        pred_best_choice1 = (ITE[is_treated] > 0) * 1
        hits = np.append(best[is_control] == pred_best_choice0,
                         best[is_treated] == pred_best_choice1)
        correct_predicted_probability = np.sum(hits) / (pred_best_choice0.size + pred_best_choice1.size)
        print('Fraction best choice: ' + str(correct_predicted_probability))

    return ITE, correct_predicted_probability
167+
168+
169+
170+
171+
def get_ITE_DeepSurvT(model0, model1, X, treatment, best_treatment=None, death_probability=0.5):
    """Estimate individual treatment effects (ITE) with two DeepSurv models.

    ``model0`` provides the curves under control and ``model1`` the curves
    under treatment. For every sample the time at which the predicted survival
    curve is closest to ``death_probability`` is read off both curves; the ITE
    is that time under treatment minus that time under control, so a positive
    value favours treatment.

    Args:
        model0: Object with ``predict_surv_df(X)`` for the control arm; the
            returned frame has one column per sample and the time grid as index.
        model1: Same interface, for the treated arm.
        X: Covariates, indexable with a boolean mask.
        treatment: Per-sample treatment assignment (0 = control, 1 = treated).
        best_treatment: Optional per-sample ground-truth best treatment (0/1).
            When given, the fraction of samples whose predicted best treatment
            matches is computed and printed.
        death_probability: Survival probability at which the curves are read.

    Returns:
        ``(ITE, correct_predicted_probability)`` where ``ITE`` has one entry
        per row of ``X`` and the second item is ``None`` when
        ``best_treatment`` is not given.
    """

    def time_at_probability(surv):
        """Return, per column, the index value whose survival is nearest the target."""
        if surv.shape[1] == 0:
            return np.zeros(0)
        # argmin along axis 0 picks the first closest row of every column.
        nearest_rows = np.abs(surv.values - death_probability).argmin(axis=0)
        return np.asarray(surv.index)[nearest_rows].astype(float)

    X0 = X[treatment == 0]
    X1 = X[treatment == 1]

    # Factual curves come from the model of the received arm, counterfactual
    # curves from the model of the other arm.
    surv0 = model0.predict_surv_df(X0)
    surv0_cf = model1.predict_surv_df(X0)
    surv1 = model1.predict_surv_df(X1)
    surv1_cf = model0.predict_surv_df(X1)

    ITE0 = time_at_probability(surv0_cf) - time_at_probability(surv0)
    ITE1 = time_at_probability(surv1) - time_at_probability(surv1_cf)

    treatment_arr = np.asarray(treatment).reshape(-1)
    is_control = treatment_arr == 0

    correct_predicted_probability = None
    # `is not None` instead of truthiness: the truth value of a multi-element
    # array is ambiguous and raises ValueError.
    if best_treatment is not None:
        best = np.asarray(best_treatment)
        if best.ndim > 1:
            # NOTE(review): the earlier implementation read column 3 of a 2-D
            # array, but only for the treated arm. The same column is used for
            # both arms here - confirm the layout with the caller.
            best = best[:, 3]
        pred_best_choice0 = (ITE0 > 0) * 1
        pred_best_choice1 = (ITE1 > 0) * 1
        hits = np.append(best[is_control] == pred_best_choice0,
                         best[treatment_arr == 1] == pred_best_choice1)
        correct_predicted_probability = np.sum(hits) / (pred_best_choice0.size + pred_best_choice1.size)
        print('Fraction best choice: ' + str(correct_predicted_probability))

    # Scatter the per-arm effects back into the original sample order; every
    # sample that is not control is taken from the treated arm.
    ITE = np.zeros(X.shape[0])
    ITE[is_control] = ITE0
    ITE[~is_control] = ITE1

    return ITE, correct_predicted_probability
272+
273+
274+
def analyse_randomized_test_set(pred_ite, Y_test, event_test, treatment_test, C_index=None, method_name='set_name', save_path=None, new_figure=True, annotate=True):
    """Compare patients treated according to vs. against the model's recommendation.

    A predicted ITE greater than zero recommends treatment (T=1); anything else
    recommends control (T=0). Patients whose received treatment matches the
    recommendation form the "recommended" group, all others the
    "anti-recommended" group. The groups are compared with a log-rank test
    (summary printed) and their Kaplan-Meier curves are drawn on the current
    matplotlib axes.

    Args:
        pred_ite: Predicted individual treatment effect per patient.
        Y_test: Observed times, indexable with a boolean mask.
        event_test: Event indicators, indexable with a boolean mask.
        treatment_test: Received treatment per patient (0/1).
        C_index: Value shown in the annotation; only its first five characters
            are printed.
        method_name: Prefix of the curve labels. ``None`` switches the labels
            to 'Treated' / 'Control'.
        save_path: If truthy, the figure is saved there as PDF.
        new_figure: Selects the first colour pair without confidence bands;
            otherwise the second colour pair with default plot options is used.
        annotate: Whether to write p-value, C-index and the share of patients
            recommended for T=1 into the plot.

    Returns:
        None.
    """
    mask_recommended = (pred_ite > 0) == treatment_test
    # Exact complement, so every patient belongs to exactly one group. Testing
    # `pred_ite < 0` separately would put untreated patients with an ITE of 0
    # into both groups and treated ones into neither.
    mask_antirecommended = np.logical_not(mask_recommended)

    recommended_times = Y_test[mask_recommended]
    recommended_event = event_test[mask_recommended]
    antirecommended_times = Y_test[mask_antirecommended]
    antirecommended_event = event_test[mask_antirecommended]

    logrank_result = logrank_test(recommended_times, antirecommended_times,
                                  recommended_event, antirecommended_event, alpha=0.95)
    logrank_result.print_summary(style='ascii', decimals=4)

    colors = sns.color_palette()
    kmf = KaplanMeierFitter()
    kmf_cf = KaplanMeierFitter()
    if method_name is None:
        kmf.fit(recommended_times, recommended_event, label='Treated')
        kmf_cf.fit(antirecommended_times, antirecommended_event, label='Control')
    else:
        kmf.fit(recommended_times, recommended_event, label=method_name + ' Recommendation')
        kmf_cf.fit(antirecommended_times, antirecommended_event, label=method_name + ' Anti-Recommendation')

    if new_figure:
        kmf.plot(c=colors[0], ci_show=False)
        kmf_cf.plot(c=colors[1], ci_show=False)
    else:
        kmf.plot(c=colors[2])
        kmf_cf.plot(c=colors[3])

    if annotate:
        # Fixed text anchors in data coordinates (x = 3).
        text_x = 3
        fraction_treat = np.sum(pred_ite > 0) / pred_ite.shape[0]
        plt.text(text_x, 0.4, f"$p$ = {logrank_result.p_value:.6f}", fontsize='small')
        plt.text(text_x, 0.3, 'C-Index=' + str(C_index)[:5], fontsize='small')
        plt.text(text_x, 0.2, f"{fraction_treat * 100:.1f}% recommended for T=1", fontsize='small')

    plt.xlabel('Survival Time [month]')
    plt.ylabel('Survival Probability')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, format='pdf')
323+
324+
325+

bites/data/RGBSG/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
## Data Description
2+
3+
The RGBSG dataset can be obtained from https://github.com/arturomoncadatorres/deepsurvk/tree/master/deepsurvk/datasets/data.
4+
To use the example files, place `rgbsg.h5` in this folder.

0 commit comments

Comments
 (0)