-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathmain.py
154 lines (126 loc) · 5.26 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
Main file to run the CNN_GRU model.
Two modes for training:
AE: training with adversarial examples
without_AE: regular training
Within each mode of training there are two modes for model evaluation:
cv: leave-one-seizure-out cross-validation
test: divide the data into training, evaluation and testing sets
"""
import tensorflow as tf
from utils.load_signals import PrepData
from utils.load_results import summary_results, load_results, auc_results
from models.model_ae import CNN_GRU
import os
import numpy as np
import sys
import warnings
import argparse
# Silence Python warnings for the whole run, unless the user configured
# warning handling explicitly on the command line (e.g. `python -W ...`),
# in which case sys.warnoptions is non-empty and is respected.
if not sys.warnoptions:
    warnings.simplefilter("ignore")
def data_loading(target, dataset, settings):
    """Load the preprocessed ictal and interictal data of a single patient.

    Each data type is produced by ``PrepData(target, type=..., settings=...)``
    followed by ``.apply()``; ictal data is loaded first, then interictal.

    Params:
        target (str): patient identifier, e.g. '1'.
        dataset (str): dataset name; not used by this function (kept so the
            call signature stays unchanged).
        settings (dict): data/cache/results directory configuration passed
            through to PrepData.
    Returns:
        tuple: (ictal_X, ictal_y, interictal_X, interictal_y, settings), where
        ``settings`` is the very dict that was passed in.
    """
    print('Data Loading...................................')
    collected = []
    for kind in ('ictal', 'interictal'):
        # Unpack immediately so a malformed result fails before the next load.
        features, labels = PrepData(target, type=kind,
                                    settings=settings).apply()
        collected.extend((features, labels))
    return tuple(collected) + (settings,)
def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to bool.

    ``type=bool`` cannot be used for this: ``bool('False')`` is True, so any
    non-empty string would switch the flag on.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def get_args(argv=None):
    """Parse the command-line options of the training script.

    Params:
        argv (list of str, optional): arguments to parse. When None, argparse
            falls back to ``sys.argv[1:]``, exactly as before.
    Returns:
        argparse.Namespace with the attributes dataset, validation, mode,
        batch_size, epoch, percentage and verbose.
    """
    parser = argparse.ArgumentParser()
    # dataset/validation are restricted to the values train()/main() handle;
    # anything else used to fail only after (or during) training.
    parser.add_argument('-d', '--dataset', help='Epilepsy dataset [CHBMIT or FB]',
                        type=str, default='CHBMIT', choices=['CHBMIT', 'FB'])
    parser.add_argument('-val', '--validation', help='Validation training mode [cv, test]',
                        type=str, default='cv', choices=['cv', 'test'])
    parser.add_argument('-m', '--mode', help='Training mode augmentstion [AE, without_AE]',
                        type=str, default='AE')
    parser.add_argument('-b', '--batch_size', help='batch size', type=int, default=256)
    parser.add_argument('-e', '--epoch', help='Full training epochs', type=int, default=50)
    parser.add_argument('-p', '--percentage', help='Percentage of the AE of the total data to generate',
                        type=int, default=40)
    # Accepts "-v", "-v True" and "-v False"; a bare "-v" means True.
    parser.add_argument('-v', '--verbose', help='Verbose output [True, False]',
                        type=_str2bool, nargs='?', const=True, default=False)
    return parser.parse_args(argv)
# Patient identifiers evaluated for each supported dataset.
_PATIENTS_BY_DATASET = {
    'CHBMIT': ['1', '2', '3', '5', '9', '10', '13', '14', '18', '19', '20',
               '21', '23'],
    'FB': ['1', '3', '4', '5', '6', '14', '15', '16', '17', '18', '19', '20',
           '21'],
}


def _build_settings(dataset):
    """Return the data, cache and results directory settings for *dataset*.

    Result directories are built with os.path.join and keep a trailing
    separator, like the original hard-coded values. On Windows the strings are
    identical to the former "results\\...\\" literals; on POSIX they now form
    real nested directories instead of one backslash-named directory.
    """
    def results_dir(name):
        return os.path.join("results", name, "")

    return {"dataset": dataset,
            "datadir": dataset,
            "cachedir": "{}_cache".format(dataset),
            "results": results_dir("results_{}".format(dataset)),
            "resultsCV": results_dir("results_CV_{}".format(dataset)),
            "resultsCV_AE": results_dir("resultsCV_{}_AE".format(dataset)),
            "results_AE": results_dir("results_{}_AE".format(dataset))}


def train(args):
    """Train and evaluate a CNN_GRU model for every patient of ``args.dataset``.

    For each patient, the data is loaded with ``data_loading``. The default
    TensorFlow graph is then reset, a new session is opened, and
    ``CNN_GRU.train_eval_test_ae`` is run with the command-line options.

    Params:
        args (argparse.Namespace): parsed options (see ``get_args``).
    Returns:
        dict: the settings dictionary (data, cache and results directories).
    Raises:
        ValueError: if ``args.dataset`` is not 'CHBMIT' or 'FB'.
    """
    dataset = args.dataset
    # Validate before anything else; an unknown dataset previously left
    # `patients` unbound and failed with an UnboundLocalError.
    patients = _PATIENTS_BY_DATASET.get(dataset)
    if patients is None:
        raise ValueError("Unknown dataset {!r}; expected one of {}".format(
            dataset, sorted(_PATIENTS_BY_DATASET)))
    settings = _build_settings(dataset)
    noise_limit = 0.3  # limit of the max value for generated AE examples
    for target in patients:
        # loading the data for each patient
        ictal_X, ictal_y, interictal_X, interictal_y, settings = data_loading(
            target, dataset, settings)
        # Reset the default graph so each patient's model is built from scratch.
        tf.reset_default_graph()
        graph = tf.get_default_graph()
        session = tf.Session()
        try:
            # Input shape is taken from axes 2 and 3 of the ictal data.
            model = CNN_GRU([ictal_X.shape[2], ictal_X.shape[3]],
                            dataset, noise_limit, graph)  # Build the graph
            model.train_eval_test_ae(
                session, target, ictal_X, ictal_y, interictal_X,
                interictal_y, settings, validation=args.validation,
                mode=args.mode, batch_size=args.batch_size, epoch=args.epoch,
                percentage=args.percentage, verbose=args.verbose)
        finally:
            # Release the session's resources even if training fails.
            session.close()
    return settings
def main():
    """Train on every patient, then print a summary of the saved results.

    The results directory that is read back is chosen from both
    ``args.validation`` and ``args.mode``, matching the corresponding key of
    the settings returned by ``train``.

    Raises:
        ValueError: if ``args.validation`` is neither 'cv' nor 'test'
            (checked before any training starts).
    """
    args = get_args()
    # Key on (validation, AE mode) together. An if/elif chain that tests
    # `validation == 'cv'` before `validation == 'cv' and mode == 'AE'` can
    # never reach the *_AE directories.
    results_keys = {
        ('cv', False): 'resultsCV',
        ('cv', True): 'resultsCV_AE',
        ('test', False): 'results',
        ('test', True): 'results_AE',
    }
    lookup = (args.validation, args.mode == 'AE')
    if lookup not in results_keys:
        raise ValueError("Unknown validation mode {!r}; expected 'cv' or 'test'"
                         .format(args.validation))
    settings = train(args)
    # print the results
    print("\n")
    print('************ Final Results on {} Dataset *********************'.format(args.dataset))
    results_path = settings[results_keys[lookup]]
    os.makedirs(results_path, exist_ok=True)
    data_results, patients = load_results(results_path, args.dataset)
    summary_results(patients, data_results)
    print("AVG_AUC: ", np.mean(auc_results(data_results, patients)))
if __name__ == "__main__":
    # Run the training/evaluation pipeline only when executed as a script,
    # not when this module is imported.
    main()