-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathconf_models.py
96 lines (90 loc) · 3.46 KB
/
conf_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
def _make_adaboost(random_seed):
    """Build the AdaBoost model with a depth-3 decision-tree weak learner.

    scikit-learn renamed AdaBoostClassifier's ``base_estimator`` parameter to
    ``estimator`` in 1.2 and removed ``base_estimator`` entirely in 1.4. Try
    the current keyword first and fall back to the legacy one so the same
    configuration works on old and new releases.
    """
    weak_learner = DecisionTreeClassifier(max_depth=3, random_state=random_seed)
    try:
        return AdaBoostClassifier(estimator=weak_learner, random_state=random_seed)
    except TypeError:
        # scikit-learn < 1.2 does not accept ``estimator``; use the old name.
        return AdaBoostClassifier(base_estimator=weak_learner, random_state=random_seed)


def get_classification_models(random_seed):
    """Return the classifiers to evaluate, as ``(name, estimator)`` pairs.

    Every call builds fresh, unfitted estimator instances, so callers can fit
    them without affecting other callers.

    Args:
        random_seed: Value passed as ``random_state`` to every estimator that
            accepts one (and to the AdaBoost weak learner), so runs are
            reproducible.

    Returns:
        A list of ``(name, estimator)`` tuples. ``name`` is a display string
        shaped like the constructor call, e.g. ``"SVC(kernel='rbf')"``.
    """
    return [
        (
            'AdaBoostClassifier()',  # Adaptive Boosting (AdaBoost)
            _make_adaboost(random_seed),
        ),
        (
            'DecisionTreeClassifier()',  # Decision Tree (DT)
            DecisionTreeClassifier(random_state=random_seed),
        ),
        # DummyClassifier baselines, one per prediction strategy.
        (
            'DummyClassifier(strategy=\'most_frequent\')',
            DummyClassifier(strategy='most_frequent', random_state=random_seed),
        ),
        (
            'DummyClassifier(strategy=\'prior\')',
            DummyClassifier(strategy='prior', random_state=random_seed),
        ),
        (
            'DummyClassifier(strategy=\'stratified\')',
            DummyClassifier(strategy='stratified', random_state=random_seed),
        ),
        (
            'DummyClassifier(strategy=\'uniform\')',
            DummyClassifier(strategy='uniform', random_state=random_seed),
        ),
        (
            'GaussianNB()',  # Naive Bayes (NB)
            GaussianNB(),
        ),
        (
            'GradientBoostingClassifier()',  # Gradient Boosting Machine (GBM)
            GradientBoostingClassifier(random_state=random_seed),
        ),
        (
            'KNeighborsClassifier()',  # k-Nearest Neighbour (KNN)
            KNeighborsClassifier(),
        ),
        (
            'LinearDiscriminantAnalysis()',  # Linear Discriminant Analysis (LDA)
            LinearDiscriminantAnalysis(),
        ),
        (
            'LogisticRegression()',  # Logistic Regression
            # Large iteration cap so the solver is not stopped early.
            LogisticRegression(max_iter=100000, random_state=random_seed),
        ),
        (
            'MLPClassifier()',  # Multi-Layer Perceptron Neural Network
            MLPClassifier(max_iter=100000, random_state=random_seed),
        ),
        (
            'QuadraticDiscriminantAnalysis()',  # Quadratic Discriminant Analysis (QDA)
            QuadraticDiscriminantAnalysis(),
        ),
        (
            'RandomForestClassifier()',  # Random Forest (RF)
            RandomForestClassifier(random_state=random_seed),
        ),
        (
            'SVC(kernel=\'linear\')',  # Support Vector Machine with linear kernel (LinearSVM)
            # probability=True enables predict_proba on SVC.
            SVC(kernel='linear', max_iter=100000, probability=True, random_state=random_seed),
        ),
        (
            'SVC(kernel=\'rbf\')',  # Support Vector Machine with radial kernel (RadialSVM)
            SVC(kernel='rbf', max_iter=100000, probability=True, random_state=random_seed),
        ),
    ]
def get_group_importance_models(random_seed):
    """List the models used to compute feature-group importance.

    Args:
        random_seed: Passed as ``random_state`` to the forest for reproducibility.

    Returns:
        A one-element list holding a ``(name, estimator)`` tuple with a fresh,
        unfitted random forest.
    """
    forest = RandomForestClassifier(random_state=random_seed)  # Random Forest (RF)
    return [('RandomForestClassifier()', forest)]
def get_feature_importance_models(random_seed):
    """List the models used to compute per-feature importance.

    Args:
        random_seed: Passed as ``random_state`` to the forest for reproducibility.

    Returns:
        A one-element list holding a ``(name, estimator)`` tuple with a fresh,
        unfitted random forest.
    """
    label = 'RandomForestClassifier()'  # Random Forest (RF)
    model = RandomForestClassifier(random_state=random_seed)
    return [(label, model)]