Skip to content

Commit a8cba70

Browse files
committed
Initial commit
0 parents  commit a8cba70

File tree

89 files changed

+4910
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

89 files changed

+4910
-0
lines changed

.DS_Store

6 KB
Binary file not shown.

Dataset/Loaders/__init__.py

Whitespace-only changes.
131 Bytes
Binary file not shown.
154 Bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

Dataset/Loaders/hcpRestNew.py

+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
2+
import pickle
3+
import pandas
4+
import numpy as np
5+
import random
6+
import torch
7+
8+
9+
def loadTorchSave(atlas):
10+
11+
baseFolderName = None ### replace with your data directory
12+
13+
if(atlas == "AAL"):
14+
fileName = baseFolderName + "/hcpRest_aal.save"
15+
elif(atlas == "Schaefer"):
16+
fileName = baseFolderName + "/hcpRest_schaefer.save"
17+
18+
19+
subjectDict = torch.load(fileName)
20+
21+
subjectDatas = []
22+
subjectIds = []
23+
24+
for subjectId in subjectDict:
25+
26+
subjectData = subjectDict[subjectId]
27+
28+
if subjectData.shape[0] != 1200:
29+
print("Passing short subject")
30+
continue
31+
32+
subjectIds.append(subjectId)
33+
subjectDatas.append(subjectData.T)
34+
35+
36+
return subjectDatas, subjectIds
37+
38+
39+
40+
def getLabels(subjectIds, targetTask):
41+
42+
43+
temp = pandas.read_csv(".../Datasets/HCP_1200/Preprocessed/pheno.csv").to_numpy() ### replace with the pheno.csv file directory
44+
45+
phenoInfos = {}
46+
for row in temp:
47+
phenoInfos[str(row[0])] = {"gender": row[3], "age" : row[4], "fIQ" : row[121]}
48+
49+
labels = []
50+
ages = []
51+
52+
badSubjIds = []
53+
54+
for subjectId in subjectIds:
55+
56+
label = phenoInfos[subjectId][targetTask]
57+
58+
agePheno = phenoInfos[subjectId]["age"]
59+
if("-" not in agePheno):
60+
age = float(agePheno.split("+")[0])
61+
else:
62+
age = (float(agePheno.split("-")[0]))
63+
64+
ages.append(age)
65+
66+
if(targetTask == "gender"):
67+
68+
label = 1 if label == 'M' else 0
69+
70+
if(targetTask == "age"):
71+
72+
if("-" not in label):
73+
label = float(label.split("+")[0])
74+
else:
75+
label = (float(label.split("-")[0]) + float(label.split("-")[1])) / 2.0
76+
77+
if(targetTask == "fIQ"):
78+
if(np.isnan(label)):
79+
badSubjIds.append(subjectId)
80+
81+
labels.append(label)
82+
83+
return labels, badSubjIds, ages
84+
85+
86+
def hcpRestLoader(atlas, targetTask):
87+
88+
89+
90+
91+
if(atlas == "AAL" or atlas == "Schaefer"):
92+
93+
subjectDatas_, subjectIds_ = loadTorchSave(atlas)
94+
95+
96+
if(targetTask != None):
97+
labels_, badSubjIds, ages_ = getLabels(subjectIds_, targetTask)
98+
99+
subjectDatas = []
100+
subjectIds = []
101+
102+
if(targetTask != None):
103+
labels = []
104+
ages = []
105+
106+
for i, subjectId in enumerate(subjectIds_):
107+
if(not subjectId in badSubjIds):
108+
subjectDatas.append(subjectDatas_[i])
109+
subjectIds.append(subjectIds_[i])
110+
111+
ages.append(ages_[i])
112+
labels.append(labels_[i])
113+
114+
else:
115+
116+
subjectDatas = subjectDatas_
117+
subjectIds = subjectIds_
118+
119+
classWeights = []
120+
if(targetTask == "gender"):
121+
for i in range(np.max(labels) + 1):
122+
classWeights.append(float(np.sum(np.array(labels) == i)))
123+
classWeights = 1/np.array(classWeights)
124+
classWeights = classWeights / np.sum(classWeights)
125+
126+
127+
128+
random.Random(12).shuffle(subjectDatas)
129+
random.Random(12).shuffle(subjectIds)
130+
131+
if(targetTask != None):
132+
random.Random(12).shuffle(labels)
133+
random.Random(12).shuffle(ages)
134+
135+
136+
137+
if(targetTask != None):
138+
print("hcp rest data : # subjects = {}, chance level = {}".format(len(labels), np.sum(labels) / len(labels)))
139+
140+
141+
if(targetTask != None):
142+
if(targetTask == "gender"):
143+
return subjectDatas, labels, subjectIds
144+
else:
145+
return subjectDatas, labels, subjectIds, classWeights, ages, None, None
146+
else:
147+
return subjectDatas, subjectIds
148+

Dataset/__init__.py

Whitespace-only changes.
146 Bytes
Binary file not shown.
3.48 KB
Binary file not shown.

Dataset/dataset.py

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
2+
from torch.utils.data import Dataset, DataLoader
3+
from sklearn.model_selection import StratifiedKFold
4+
from random import shuffle, randrange
5+
import numpy as np
6+
import random
7+
8+
from .Loaders.hcpRestNew import hcpRestLoader
9+
10+
11+
loaderMapper = {
12+
"hcpRest" : hcpRestLoader,
13+
# add other datasets if you want
14+
}
15+
16+
def getDataset(mainOptions):
17+
18+
# if(mainOptions.supervision == "supervised"):
19+
return SupervisedDataset(mainOptions)
20+
21+
22+
23+
class SupervisedDataset(Dataset):
24+
25+
def __init__(self, mainOptions):
26+
27+
self.batchSize = mainOptions.batchSize
28+
self.dynamicLength = mainOptions.dynamicLength
29+
self.foldCount = mainOptions.kFold
30+
31+
loader = loaderMapper[mainOptions.datasets[0]]
32+
33+
self.kFold = StratifiedKFold(mainOptions.kFold, shuffle=True, random_state=0) if mainOptions.kFold is not None else None
34+
self.k = None
35+
36+
self.data, self.labels, self.subjectIds = loader(mainOptions.atlas, mainOptions.targetTask)
37+
38+
self.targetData = None
39+
self.targetLabel = None
40+
41+
def __len__(self):
42+
return len(self.data) if isinstance(self.targetData, type(None)) else len(self.targetData)
43+
44+
def get_nOfTrains_perFold(self):
45+
46+
return len(self.data)
47+
48+
def setFold(self, fold, train=True):
49+
50+
self.k = fold
51+
self.train = train
52+
53+
if(self.kFold == None): # if this is the case, train must be True
54+
trainIdx = list(range(len(self.data)))
55+
else:
56+
trainIdx, testIdx = list(self.kFold.split(self.data, self.labels))[fold]
57+
58+
random.Random(12).shuffle(trainIdx)
59+
60+
self.targetData = [self.data[idx] for idx in trainIdx] if train else [self.data[idx] for idx in testIdx]
61+
self.targetLabels = [self.labels[idx] for idx in trainIdx] if train else [self.labels[idx] for idx in testIdx]
62+
self.targetSubjIds = [self.subjectIds[idx] for idx in trainIdx] if train else [self.subjectIds[idx] for idx in testIdx]
63+
64+
def getFold(self, fold, train=True):
65+
66+
self.setFold(fold, train)
67+
68+
if(train):
69+
return DataLoader(self, batch_size=self.batchSize, shuffle=False)
70+
else:
71+
return DataLoader(self, batch_size=1, shuffle=False)
72+
73+
74+
def __getitem__(self, idx):
75+
76+
subject = self.targetData[idx]
77+
label = self.targetLabels[idx]
78+
subjId = self.targetSubjIds[idx]
79+
80+
# normalize timeseries
81+
timeseries = subject # (numberOfRois, time)
82+
timeseries = (timeseries - np.mean(timeseries, axis=1, keepdims=True)) / np.std(timeseries, axis=1, keepdims=True)
83+
84+
# dynamic sampling if train
85+
if(self.train):
86+
87+
if(timeseries.shape[1] < self.dynamicLength):
88+
print(timeseries.shape[1], self.dynamicLength)
89+
90+
samplingInit = 0 if timeseries.shape[1] == self.dynamicLength else randrange(timeseries.shape[1] - self.dynamicLength)
91+
timeseries = timeseries[:, samplingInit : samplingInit + self.dynamicLength]
92+
93+
return {"timeseries" : timeseries.astype(np.float32), "label" : label, "subjId" : subjId}
94+
95+
96+
97+
98+
99+
100+

0 commit comments

Comments
 (0)