from .utilities import _get_amount_matrix, _get_cost_matrix, _get_density_curve_data
from .utilities import get_amount_cost_df, get_invariant_metrics_df, get_confusion_matrix_and_metrics_df, get_gain_curve_data, get_expected_calibration_error

- from .thresholds import get_optimized_thresholds_df
+ from .thresholds import _get_subset_optimal_thresholds_df

def calibration_curve_plot(true_y, predicted_proba,
n_bins = 10,
@@ -657,6 +657,8 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C
if beta < 0:
raise ValueError("beta should be >=0 in the F-beta score")

+ true_y = np.array(true_y)  # convert to array
+
precision, recall, thresholds = precision_recall_curve(true_y, predicted_proba)

listTr = thresholds.tolist()
@@ -668,7 +670,7 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C
F_beta_score = fbeta_score(true_y, y_pred, beta = beta, zero_division = 0)
listFbeta.append(F_beta_score)

- listTr.append(None)
+ listTr.append('-')
listFbeta.append(0)

area_under_pr_curve = auc(recall, precision)
@@ -691,7 +693,7 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C
f_scores = np.linspace(0.2, 0.8, num = 4)

for f_score in f_scores:
- x = np.linspace(0.01, 1)
+ x = np.linspace(0.01, 1.01)
y = f_score * x / (x + beta * beta * (x - f_score))
X = x[y >= 0]
listX = X.tolist()
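For the default beta = 1, the expression above is the standard iso-F curve in precision-recall space; writing precision as x and recall as y (a quick check, not part of the diff):

    F_1 = \frac{2xy}{x + y}
    \quad\Longrightarrow\quad
    y = \frac{F_1\,x}{2x - F_1}

which is `f_score * x / (x + beta * beta * (x - f_score))` with beta = 1. The widened endpoint (1.01) presumably just lets each reference curve extend slightly past the right edge of the plot.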
@@ -842,8 +844,7 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
except:
n_of_decimals = 4

- threshold_values = list(np.arange(0, 1 + threshold_step, threshold_step))
-
+ threshold_values = list(np.round(np.arange(0, 1 + threshold_step, threshold_step), n_of_decimals))  # define thresholds list
main_title = f"<b>{title}</b><br>"

# VIOLIN PLOT
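Some context for the np.round change above, and for the later edits that replace round(threshold, n_of_decimals) with plain str(threshold): np.arange with a fractional step accumulates floating-point error, so the raw values produce labels like '0.30000000000000004'. A standalone illustration in plain NumPy:

    import numpy as np

    step = 0.1
    raw = np.arange(0, 1 + step, step)
    print(raw[3])                # 0.30000000000000004 on IEEE-754 doubles
    print(np.round(raw, 1)[3])   # 0.3

Rounding once when the list is built is what makes the later per-use rounding redundant.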
@@ -859,7 +860,7 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
# STRIP PLOT
for threshold in threshold_values:

- threshold_string = "thresh_" + str(round(threshold, n_of_decimals))
+ threshold_string = "thresh_" + str(threshold)


conditions = [(data_df['class'] == 0) & (data_df['pred'] < threshold),
@@ -922,7 +923,7 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
# + subtitle + titles[threshold] + '</span>',
# y = 0.965, yanchor = 'bottom')}
{"value": "set "}],
- label = str(round(threshold_values[i], n_of_decimals)),
+ label = str(threshold_values[i]),
)

n_of_strip_plots = length_fig_list[i]
@@ -994,7 +995,7 @@ def predicted_proba_density_curve_plot(true_y, predicted_proba,
except:
n_of_decimals = 4

- threshold_values = list(np.arange(0, 1 + threshold_step, threshold_step))  # define thresholds array
+ threshold_values = list(np.round(np.arange(0, 1 + threshold_step, threshold_step), n_of_decimals))  # define thresholds list
main_title = f"<b>{title}</b><br>"

# get density curve data
@@ -1155,7 +1156,7 @@ def predicted_proba_density_curve_plot(true_y, predicted_proba,
y = 0.965, yanchor = 'bottom'),
"annotations": annotation_list}
],
- label = str(round(threshold, n_of_decimals))
+ label = str(threshold)
)

step["args"][0]["visible"][j] = True  # TN density curve
@@ -1192,28 +1193,26 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
return fig

def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
- amounts = None, cost_dict = None, optimize_threshold = None,
- N_subsets = 70, subsets_size = 0.2, with_replacement = False,
- currency = '€', random_state = None,
+ amounts = None, cost_dict = None,
+ currency = '€',
title = 'Interactive Confusion Matrix'):

"""
Returns plotly figure of interactive and customized confusion matrix,
one for each threshold that can be selected with a slider,
- displaying additional information (metrics, optimized thresholds).
+ displaying additional information (metrics, optimal thresholds).

Returns three dataframes containing:
- metrics that depend on threshold
- metrics that don't depend on threshold,
- - optimized thresholds (or empty)
+ - optimal thresholds: threshold values that, for this subset, maximize (or minimize) the related metric value

Plot is constituted by:
- table displaying metrics that vary based on the threshold selected:
Accuracy, Balanced Acc., F1, Precision, Recall, MCC, Cohen's K
- table displaying metrics that don't depend on threshold:
ROC auc, Precision-Recall auc, Brier score
- - when optimize_threshold is given:
- table displayng thresholds optimized using GHOST method for any of the following metrics:
+ - table displaying best thresholds for the following metrics:
Cohen's Kappa, Matthew's Correlation Coefficient, ROC, F-beta scores (beta = 1, 0.5, 2)
and for minimal total cost
- confusion matrix (annotated heatmap) that varies based on the threshold selected
@@ -1237,25 +1236,10 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
dict containing costs associated to each class (TN, FP, FN, TP)
with keys "TN", "FP", "FN", "TP"
and values that can be both lists (with coherent lengths) and/or floats
- (output from get_cost_dict)
- necessary when optimizing threshold for minimal total costs
- optimize_threshold: {'all', 'ROC', 'MCC', 'Kappa', 'Fscore', 'Cost'}
- or list containing allowed values except 'all', default=None
- metrics for which thresholds will be optimized
- 'all' is equvalent to ['ROC', 'MCC', 'Kappa', 'Fscore'] if cost_dict=None, ['ROC', 'MCC', 'Kappa', 'Fscore', 'Cost'] otherwise
- N_subsets: int, default=70
- Number of subsets used in GHOST optimization process
- subsets_size: float or int, default=0.2
- Size of the subsets used in GHOST optimization process.
- If float, represents the proportion of the dataset to include in the subsets.
- If integer, it represents the actual number of instances to include in the subsets.
- with_replacement: bool, default=False
- If True, the subsets used in GHOST optimization process are drawn randomly with replacement, without otherwise.
+ (output from get_cost_dict)
currency: str, default='€'
currency symbol to be visualized. For unusual currencies, you can use their HTML code representation
(eg. Indian rupee: '&#8377;')
- random_state: int, default=None
- Controls the randomness of the bootstrapping of the samples when optimizing thresholds with GHOST method
title: str, default='Interactive Confusion Matrix'
The main title of the plot.

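For orientation, a minimal call against the trimmed signature above. The toy labels, scores and per-class costs are illustrative only, the plain dict stands in for the get_cost_dict helper mentioned in the docstring, and the unpacking assumes the figure is returned first, followed by the three dataframes:

    import numpy as np

    true_y = np.array([0, 0, 1, 1, 0, 1, 0, 1])
    predicted_proba = np.array([0.05, 0.30, 0.42, 0.81, 0.12, 0.67, 0.55, 0.90])

    # scalars and/or per-observation lists are both accepted, per the docstring
    cost_dict = {"TN": 0, "FP": 10, "FN": 50, "TP": 0}

    fig, *result_dfs = confusion_matrix_plot(true_y, predicted_proba,
                                             threshold_step = 0.05,
                                             cost_dict = cost_dict,
                                             currency = '€')
    fig.show()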
@@ -1274,7 +1258,7 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
except:
n_of_decimals = 4

- threshold_values = list(np.arange(0, 1 + threshold_step, threshold_step))  # define thresholds array
+ threshold_values = list(np.round(np.arange(0, 1 + threshold_step, threshold_step), n_of_decimals))  # define thresholds list
n_data = len(true_y)
main_title = f"<b>{title}</b><br>"
subtitle = "Total obs: " + '{:,}'.format(n_data)
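The subtitle strings use standard Python format specs rather than anything library-specific; for reference:

    '{:,}'.format(125000)          # '125,000'   (thousands separator for the obs count)
    '{:,.2f}'.format(98765.4321)   # '98,765.43' (two decimals, prefixed with the currency symbol below)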
@@ -1284,6 +1268,12 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
tot_amount = sum(amounts)
subtitle += "<br>Total amount: " + currency + '{:,.2f}'.format(tot_amount)

+ if cost_dict is not None:
+ cost_TN = []
+ cost_FP = []
+ cost_FN = []
+ cost_TP = []
+
# initialize annotation matrix
annotations_fixed = np.array([[["TN", "True Negative"], ["FP", "False Positive"]],
[["FN", "False Negative"], ["TP", "True Positive"]]])
@@ -1302,26 +1292,6 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
cells = dict(values = [constant_metrics_df['invariant_metric'], constant_metrics_df['value']])
), row = 1, col = 2)

- # create table with optimized thresholds or empty:
- if optimize_threshold is not None:
-
- # compute optimized thresholds and create dataframe
- optimal_thresholds_df = get_optimized_thresholds_df(optimize_threshold = optimize_threshold,
- threshold_values = threshold_values[1:-1],
- true_y = true_y,
- predicted_proba = predicted_proba,
- cost_dict = cost_dict,
- N_subsets = N_subsets, subsets_size = subsets_size,
- with_replacement = with_replacement,
- random_state = random_state)
- fig.add_trace(
- go.Table(header = dict(values = ['Optimized Metric', 'Optimal Threshold']),
- cells = dict(values = [optimal_thresholds_df['optimized_metric'], optimal_thresholds_df['optimal_threshold']])
- ), row = 1, col = 3)
- else:
- optimal_thresholds_df = None  # needed for return statement
- fig.add_trace(go.Table({}), row = 1, col = 3)
-
# create dynamic titles dictionary (will be empty if cost is not given)
titles = {}
@@ -1356,6 +1326,11 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,

if cost_dict:
cost_matrix = _get_cost_matrix(true_y, predicted_proba, threshold, cost_dict)
+ cost_TN.append(cost_matrix[0,0])
+ cost_FP.append(cost_matrix[0,1])
+ cost_FN.append(cost_matrix[1,0])
+ cost_TP.append(cost_matrix[1,1])
+
total_cost = cost_matrix.sum()
annotations = np.dstack((annotations, cost_matrix, cost_matrix / total_cost))  # add cost matrix and perc. matrix
annotations_max_index += 2
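The indexing behind the four appends follows the [[TN, FP], [FN, TP]] layout that the annotation matrix earlier in this function establishes, and that _get_cost_matrix evidently returns; spelled out with made-up numbers:

    import numpy as np

    cost_matrix = np.array([[  0.0,  75.5],   # row 0: [TN, FP]
                            [310.0,   0.0]])  # row 1: [FN, TP]

    cost_tn, cost_fp = cost_matrix[0, 0], cost_matrix[0, 1]
    cost_fn, cost_tp = cost_matrix[1, 0], cost_matrix[1, 1]
    total_cost = cost_matrix.sum()            # 385.5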
@@ -1381,7 +1356,7 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
fig.add_trace(go.Heatmap(z = matrix,
text = annotations,
texttemplate = "<b>%{text[0]}</b><br>" + template,
- name = "threshold: " + str(round(threshold, n_of_decimals)),
+ name = "threshold: " + str(threshold),
hovertemplate = "<b>%{text[1]}</b><br>Count: " + template,
x = ['False', 'True'],
y = ['True', 'False'],
@@ -1392,15 +1367,37 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
# pivot metrics_dep_on_threshold_df
name_col = metrics_dep_on_threshold_df.columns[0]
value_col = metrics_dep_on_threshold_df.columns[1]
- metrics_dep_on_threshold_df = metrics_dep_on_threshold_df.pivot(columns = name_col, values = value_col, index = 'threshold').reset_index('threshold').rename_axis(None, axis = 1)
+ metrics_dep_on_threshold_df = metrics_dep_on_threshold_df.pivot(columns = name_col, values = value_col, index = 'threshold').reset_index('threshold').rename_axis(None, axis = 1)

- # fig.data[0] is the constant metrics table, fig.data[1] is the optimal threshold table, always visible
- fig.data[2].visible = True  # first variable metrics table
- fig.data[3].visible = True  # first confusion matrix
+ # create table with optimal thresholds
+ # compute optimal thresholds and create dataframe
+ if cost_dict is not None:
+ cost_per_threshold_df = pd.DataFrame(zip(threshold_values,
+ cost_TN, cost_FP, cost_FN, cost_TP),
+ columns = ['threshold',
+ 'cost_TN', 'cost_FP', 'cost_FN', 'cost_TP']).sort_values(by = 'threshold')
+
+ cost_per_threshold_df['total_cost'] = cost_per_threshold_df[['cost_TN', 'cost_FP',
+ 'cost_FN', 'cost_TP']].apply(sum, axis = 1)
+
+ else:
+ cost_per_threshold_df = None
+
+ optimal_thresholds_df = _get_subset_optimal_thresholds_df(metrics_dep_on_threshold_df, cost_per_threshold_df)
+
+ fig.add_trace(
+ go.Table(header = dict(values = ['Metric', 'Optimal Threshold']),
+ cells = dict(values = [optimal_thresholds_df['metric'], optimal_thresholds_df['optimal_threshold']])
+ ), row = 1, col = 3)
+
+ # fig.data[0] is the constant metrics table, always visible
+ fig.data[1].visible = True  # first variable metrics table
+ fig.data[2].visible = True  # first confusion matrix
+ # fig.data[-1] (last object) is the optimal thresholds table, always visible

# create and add slider
steps = []
- j = 2  # skip first and second trace (invariant metric table, opt. thresholds/empty table)
+ j = 1  # skip first trace (invariant metrics table)

for threshold in threshold_values:
step = dict(method = "update",
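_get_subset_optimal_thresholds_df itself is not shown in this diff. Given the two frames it receives (the pivoted per-threshold metrics and the optional per-threshold cost table), a plausible reading is an argmax per metric column plus an argmin on total_cost; the sketch below is that guess, not the helper's real implementation, and only the output column names 'metric' and 'optimal_threshold' are confirmed by the table cells above:

    import pandas as pd

    def subset_optimal_thresholds_sketch(metrics_df, cost_df=None):
        # metrics_df: one row per threshold, a 'threshold' column plus one column per metric
        rows = []
        for metric in metrics_df.columns.drop('threshold'):
            best = metrics_df.loc[metrics_df[metric].idxmax(), 'threshold']   # maximize each metric
            rows.append((metric, best))
        if cost_df is not None:
            best = cost_df.loc[cost_df['total_cost'].idxmin(), 'threshold']   # minimize total cost
            rows.append(('total cost', best))
        return pd.DataFrame(rows, columns=['metric', 'optimal_threshold'])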
@@ -1409,13 +1406,13 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
+ subtitle + titles[threshold] + '</span>',
y = 0.965, yanchor = 'bottom')}
],
- label = str(round(threshold, n_of_decimals))
+ label = str(threshold)
)

step["args"][0]["visible"][0] = True  # constant metric table always visible
- step["args"][0]["visible"][1] = True  # opt. thresholds/empty table always visible
step["args"][0]["visible"][j] = True  # threshold related confusion matrix
step["args"][0]["visible"][j + 1] = True  # threshold related variable metrics table
+ step["args"][0]["visible"][-1] = True  # optimal thresholds table always visible
steps.append(step)
j += 2  # add 2 to trace index (confusion matrix and variable metrics table)

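The re-indexing in the hunk above follows from the new trace order: index 0 is the invariant metrics table, each threshold then contributes a pair of traces (confusion matrix and variable metrics table), and the optimal thresholds table is now the last trace. A compact restatement of the bookkeeping the loop performs, with an explicit visibility list and names of my own choosing:

    steps = []
    n_traces = 1 + 2 * len(threshold_values) + 1   # invariant table + per-threshold pairs + thresholds table

    for i, threshold in enumerate(threshold_values):
        visible = [False] * n_traces
        visible[0] = True             # invariant metrics table, always shown
        visible[-1] = True            # optimal thresholds table, always shown
        visible[1 + 2 * i] = True     # this threshold's pair of traces:
        visible[2 + 2 * i] = True     # confusion matrix and variable metrics table
        steps.append(dict(method = "update", args = [{"visible": visible}], label = str(threshold)))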
@@ -1496,7 +1493,7 @@ def confusion_linechart_plot(true_y, predicted_proba, threshold_step = 0.01,
except:
n_of_decimals = 4

- threshold_values = list(np.arange(0, 1 + threshold_step, threshold_step))
+ threshold_values = list(np.round(np.arange(0, 1 + threshold_step, threshold_step), n_of_decimals))  # define thresholds list
middle_x = (threshold_values[0] + threshold_values[-1])/2
n_data = len(true_y)
main_title = f"<b>{title}</b><br>"
@@ -1593,7 +1590,7 @@ def confusion_linechart_plot(true_y, predicted_proba, threshold_step = 0.01,

if x_intersect:
intercepts_str = 'Swaps: '
- intercepts_str += ", ".join(str(round(x, n_of_decimals)) for x in x_intersect)
+ intercepts_str += ", ".join(str(x) for x in x_intersect)
fig.add_annotation(xref = "x domain", yref = "y domain", x = 0.5, y = 1.15, showarrow = False,
text = intercepts_str, row = row_index, col = col_index)

@@ -1731,7 +1728,7 @@ def confusion_linechart_plot(true_y, predicted_proba, threshold_step = 0.01,
textposition = textposition,
hoverlabel = dict(bgcolor = 'rgb(68, 68, 68)'),
hovertemplate = '%{x:.' + str(n_of_decimals) + 'f}<extra></extra>',
- name = str(round(threshold, n_of_decimals)),
+ name = str(threshold),
marker = dict(color = color),
marker_size = 8,
visible = False),
@@ -1758,7 +1755,7 @@ def confusion_linechart_plot(true_y, predicted_proba, threshold_step = 0.01,
+ subtitle + titles[threshold] + '</span>',
y = 0.965, yanchor = 'bottom')}
],
- label = str(round(threshold, n_of_decimals))
+ label = str(threshold)
)
step["args"][0]["visible"][:static_charts_num] = [True] * static_charts_num  # line charts
step["args"][0]["visible"][j:j + markers_num] = [True] * markers_num  # line chart markers
@@ -1860,7 +1857,7 @@ def total_amount_cost_plot(true_y, predicted_proba, threshold_step = 0.01,
except:
n_of_decimals = 4

- threshold_values = list(np.arange(0, 1 + threshold_step, threshold_step))
+ threshold_values = list(np.round(np.arange(0, 1 + threshold_step, threshold_step), n_of_decimals))  # define thresholds list
middle_x = (threshold_values[0] + threshold_values[-1])/2

supported_label = ["TN", "FP", "FN", "TP"]
@@ -1955,7 +1952,7 @@ def total_amount_cost_plot(true_y, predicted_proba, threshold_step = 0.01,

if x_intersect:
intercepts_str = 'Swaps at thresholds: '
- intercepts_str += ", ".join(str(round(x, n_of_decimals)) for x in x_intersect)
+ intercepts_str += ", ".join(str(x) for x in x_intersect)

fig.update_layout(title = dict(text = f"<b>{title}</b><span style='font-size: 13px;'><br>" + subtitle + \
'<br>' + intercepts_str,