Skip to content

Commit 1d55f97

Browse files
committed
changes in cofusion matrix plot
- bug fixed - optimized thresholds not computed with GHOST method anymore, but refer to the thresholds related to the best value of the considered metric for the given set of data
1 parent eb2aa8f commit 1d55f97

File tree

1 file changed

+64
-67
lines changed

1 file changed

+64
-67
lines changed

bctools/plots.py

+64-67
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .utilities import _get_amount_matrix, _get_cost_matrix, _get_density_curve_data
1717
from .utilities import get_amount_cost_df, get_invariant_metrics_df, get_confusion_matrix_and_metrics_df, get_gain_curve_data, get_expected_calibration_error
1818

19-
from .thresholds import get_optimized_thresholds_df
19+
from .thresholds import _get_subset_optimal_thresholds_df
2020

2121
def calibration_curve_plot(true_y, predicted_proba,
2222
n_bins = 10,
@@ -657,6 +657,8 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C
657657
if beta < 0:
658658
raise ValueError("beta should be >=0 in the F-beta score")
659659

660+
true_y = np.array(true_y) #convert to array
661+
660662
precision, recall, thresholds = precision_recall_curve(true_y, predicted_proba)
661663

662664
listTr = thresholds.tolist()
@@ -668,7 +670,7 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C
668670
F_beta_score = fbeta_score(true_y, y_pred, beta=beta, zero_division=0)
669671
listFbeta.append(F_beta_score)
670672

671-
listTr.append(None)
673+
listTr.append('-')
672674
listFbeta.append(0)
673675

674676
area_under_pr_curve = auc(recall, precision)
@@ -691,7 +693,7 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C
691693
f_scores = np.linspace(0.2, 0.8, num=4)
692694

693695
for f_score in f_scores:
694-
x = np.linspace(0.01, 1)
696+
x = np.linspace(0.01, 1.01)
695697
y = f_score * x / (x + beta * beta * (x - f_score))
696698
X = x[y >= 0]
697699
listX = X.tolist()
@@ -842,8 +844,7 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
842844
except:
843845
n_of_decimals = 4
844846

845-
threshold_values = list(np.arange(0, 1 + threshold_step, threshold_step))
846-
847+
threshold_values = list(np.round(np.arange(0, 1 + threshold_step, threshold_step), n_of_decimals)) #define thresholds list
847848
main_title = f"<b>{title}</b><br>"
848849

849850
# VIOLIN PLOT
@@ -859,7 +860,7 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
859860
# STRIP PLOT
860861
for threshold in threshold_values:
861862

862-
threshold_string = "thresh_" + str(round(threshold, n_of_decimals))
863+
threshold_string = "thresh_" + str(threshold)
863864

864865

865866
conditions = [(data_df['class'] == 0) & (data_df['pred'] < threshold),
@@ -922,7 +923,7 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
922923
# + subtitle + titles[threshold] + '</span>',
923924
# y = 0.965, yanchor = 'bottom')}
924925
{"value": "set "}],
925-
label = str(round(threshold_values[i], n_of_decimals)),
926+
label = str(threshold_values[i]),
926927
)
927928

928929
n_of_strip_plots = length_fig_list[i]
@@ -994,7 +995,7 @@ def predicted_proba_density_curve_plot(true_y, predicted_proba,
994995
except:
995996
n_of_decimals = 4
996997

997-
threshold_values = list(np.arange(0, 1+threshold_step, threshold_step)) #define thresholds array
998+
threshold_values = list(np.round(np.arange(0, 1 + threshold_step, threshold_step), n_of_decimals)) #define thresholds list
998999
main_title = f"<b>{title}</b><br>"
9991000

10001001
# get density curve data
@@ -1155,7 +1156,7 @@ def predicted_proba_density_curve_plot(true_y, predicted_proba,
11551156
y = 0.965, yanchor = 'bottom'),
11561157
"annotations": annotation_list}
11571158
],
1158-
label = str(round(threshold, n_of_decimals))
1159+
label = str(threshold)
11591160
)
11601161

11611162
step["args"][0]["visible"][j] = True # TN density curve
@@ -1192,28 +1193,26 @@ def predicted_proba_density_curve_plot(true_y, predicted_proba,
11921193
return fig
11931194

11941195
def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
1195-
amounts = None, cost_dict = None, optimize_threshold = None,
1196-
N_subsets = 70, subsets_size = 0.2, with_replacement = False,
1197-
currency = '€', random_state = None,
1196+
amounts = None, cost_dict = None,
1197+
currency = '€',
11981198
title = 'Interactive Confusion Matrix'):
11991199

12001200
"""
12011201
Returns plotly figure of interactive and customized confusion matrix,
12021202
one for each threshold that can be selected with a slider,
1203-
displaying additional information (metrics, optimized thresholds).
1203+
displaying additional information (metrics, optimal thresholds).
12041204
12051205
Returns three dataframes containing:
12061206
- metrics that depend on threshold
12071207
- metrics that don't depend on threshold,
1208-
- optimized thresholds (or empty)
1208+
- optimal thresholds: threshold values that, for this subset, maximize (or minimize) the related metric value
12091209
12101210
Plot is constituted by:
12111211
- table displaying metrics that vary based on the threshold selected:
12121212
Accuracy, Balanced Acc., F1, Precision, Recall, MCC, Cohen's K
12131213
- table displaying metrics that don't depend on threshold:
12141214
ROC auc, Pecision-Recall auc, Brier score
1215-
- when optimize_threshold is given:
1216-
table displayng thresholds optimized using GHOST method for any of the following metrics:
1215+
- table displayng best thresholds for the following metrics:
12171216
Kohen's Kappa, Matthew's Correlation Coefficient, ROC, F-beta scores (beta = 1, 0.5, 2)
12181217
and for minimal total cost
12191218
- confusion matrix (annotated heatmap) that varies based on the threshold selected
@@ -1237,25 +1236,10 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
12371236
dict containing costs associated to each class (TN, FP, FN, TP)
12381237
with keys "TN", "FP", "FN", "TP"
12391238
and values that can be both lists (with coherent lenghts) and/or floats
1240-
(output from get_cost_dict)
1241-
necessary when optimizing threshold for minimal total costs
1242-
optimize_threshold: {'all', 'ROC', 'MCC', 'Kappa', 'Fscore', 'Cost'}
1243-
or list containing allowed values except 'all', default=None
1244-
metrics for which thresholds will be optimized
1245-
'all' is equvalent to ['ROC', 'MCC', 'Kappa', 'Fscore'] if cost_dict=None, ['ROC', 'MCC', 'Kappa', 'Fscore', 'Cost'] otherwise
1246-
N_subsets: int, default=70
1247-
Number of subsets used in GHOST optimization process
1248-
subsets_size: float or int, default=0.2
1249-
Size of the subsets used in GHOST optimization process.
1250-
If float, represents the proportion of the dataset to include in the subsets.
1251-
If integer, it represents the actual number of instances to include in the subsets.
1252-
with_replacement: bool, default=False
1253-
If True, the subsets used in GHOST optimization process are drawn randomly with replacement, without otherwise.
1239+
(output from get_cost_dict)
12541240
currency: str, default='€'
12551241
currency symbol to be visualized. For unusual currencies, you can use their HTML code representation
12561242
(eg. Indian rupee: '&#8377;')
1257-
random_state: int, default=None
1258-
Controls the randomness of the bootstrapping of the samples when optimizing thresholds with GHOST method
12591243
title: str, default='Interactive Confusion Matrix'
12601244
The main title of the plot.
12611245
@@ -1274,7 +1258,7 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
12741258
except:
12751259
n_of_decimals = 4
12761260

1277-
threshold_values = list(np.arange(0, 1 + threshold_step, threshold_step)) #define thresholds array
1261+
threshold_values = list(np.round(np.arange(0, 1 + threshold_step, threshold_step), n_of_decimals)) #define thresholds list
12781262
n_data = len(true_y)
12791263
main_title = f"<b>{title}</b><br>"
12801264
subtitle = "Total obs: " + '{:,}'.format(n_data)
@@ -1284,6 +1268,12 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
12841268
tot_amount = sum(amounts)
12851269
subtitle += "<br>Total amount: " + currency + '{:,.2f}'.format(tot_amount)
12861270

1271+
if cost_dict is not None:
1272+
cost_TN = []
1273+
cost_FP = []
1274+
cost_FN = []
1275+
cost_TP = []
1276+
12871277
# initialize annotation matrix
12881278
annotations_fixed = np.array([[["TN", "True Negative"], ["FP", "False Positive"]],
12891279
[["FN", "False Negative"], ["TP", "True Positive"]]])
@@ -1302,26 +1292,6 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
13021292
cells=dict(values=[constant_metrics_df['invariant_metric'], constant_metrics_df['value']])
13031293
), row=1, col=2)
13041294

1305-
# create table with optimized thresholds or empty:
1306-
if optimize_threshold is not None:
1307-
1308-
# compute optimized thresholds and create dataframe
1309-
optimal_thresholds_df = get_optimized_thresholds_df(optimize_threshold = optimize_threshold,
1310-
threshold_values = threshold_values[1:-1],
1311-
true_y = true_y,
1312-
predicted_proba = predicted_proba,
1313-
cost_dict = cost_dict,
1314-
N_subsets = N_subsets, subsets_size = subsets_size,
1315-
with_replacement = with_replacement,
1316-
random_state = random_state)
1317-
fig.add_trace(
1318-
go.Table(header=dict(values=['Optimized Metric', 'Optimal Threshold']),
1319-
cells=dict(values=[optimal_thresholds_df['optimized_metric'], optimal_thresholds_df['optimal_threshold']])
1320-
), row=1, col=3)
1321-
else:
1322-
optimal_thresholds_df = None # needed for return statement
1323-
fig.add_trace(go.Table({}), row=1, col=3)
1324-
13251295
# create dynamic titles dictionary (will be empty if cost is not given)
13261296
titles = {}
13271297

@@ -1356,6 +1326,11 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
13561326

13571327
if cost_dict:
13581328
cost_matrix = _get_cost_matrix(true_y, predicted_proba, threshold, cost_dict)
1329+
cost_TN.append(cost_matrix[0,0])
1330+
cost_FP.append(cost_matrix[0,1])
1331+
cost_FN.append(cost_matrix[1,0])
1332+
cost_TP.append(cost_matrix[1,1])
1333+
13591334
total_cost = cost_matrix.sum()
13601335
annotations = np.dstack((annotations, cost_matrix, cost_matrix/total_cost)) # add cost matrix and perc. matrix
13611336
annotations_max_index += 2
@@ -1381,7 +1356,7 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
13811356
fig.add_trace(go.Heatmap(z = matrix,
13821357
text = annotations,
13831358
texttemplate= "<b>%{text[0]}</b><br>" + template,
1384-
name="threshold: " + str(round(threshold, n_of_decimals)),
1359+
name="threshold: " + str(threshold),
13851360
hovertemplate = "<b>%{text[1]}</b><br>Count: " + template,
13861361
x=['False', 'True'],
13871362
y=['True', 'False'],
@@ -1392,15 +1367,37 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
13921367
# pivot metrics_dep_on_threshold_df
13931368
name_col = metrics_dep_on_threshold_df.columns[0]
13941369
value_col = metrics_dep_on_threshold_df.columns[1]
1395-
metrics_dep_on_threshold_df = metrics_dep_on_threshold_df.pivot(columns = name_col, values = value_col, index = 'threshold').reset_index('threshold').rename_axis(None, axis=1)
1370+
metrics_dep_on_threshold_df = metrics_dep_on_threshold_df.pivot(columns = name_col, values = value_col, index = 'threshold').reset_index('threshold').rename_axis(None, axis=1)
13961371

1397-
# fig.data[0] is the constant metrcis table, fig.data[1] is the optimal threshold table, always visible
1398-
fig.data[2].visible = True # first variable metrics table
1399-
fig.data[3].visible = True # first confusion matrix
1372+
# create table with optimal thresholds
1373+
# compute optimal thresholds and create dataframe
1374+
if cost_dict is not None:
1375+
cost_per_threshold_df = pd.DataFrame(zip(threshold_values,
1376+
cost_TN, cost_FP, cost_FN, cost_TP),
1377+
columns = ['threshold',
1378+
'cost_TN', 'cost_FP', 'cost_FN', 'cost_TP']).sort_values(by='threshold')
1379+
1380+
cost_per_threshold_df['total_cost'] = cost_per_threshold_df[['cost_TN', 'cost_FP',
1381+
'cost_FN', 'cost_TP']].apply(sum, axis = 1)
1382+
1383+
else:
1384+
cost_per_threshold_df = None
1385+
1386+
optimal_thresholds_df = _get_subset_optimal_thresholds_df(metrics_dep_on_threshold_df, cost_per_threshold_df)
1387+
1388+
fig.add_trace(
1389+
go.Table(header=dict(values=['Metric', 'Optimal Threshold']),
1390+
cells=dict(values=[optimal_thresholds_df['metric'], optimal_thresholds_df['optimal_threshold']])
1391+
), row=1, col=3)
1392+
1393+
# fig.data[0] is the constant metrcis table,always visible
1394+
fig.data[1].visible = True # first variable metrics table
1395+
fig.data[2].visible = True # first confusion matrix
1396+
# fig.data[-1] (last obkect) is the optimal thresholds table, always visible
14001397

14011398
# create and add slider
14021399
steps = []
1403-
j = 2 # skip first and second trace (invariant metric table, opt. thresholds/empty table)
1400+
j = 1 # skip first trace (invariant metrics table)
14041401

14051402
for threshold in threshold_values:
14061403
step = dict(method="update",
@@ -1409,13 +1406,13 @@ def confusion_matrix_plot(true_y, predicted_proba, threshold_step = 0.01,
14091406
+ subtitle + titles[threshold] + '</span>',
14101407
y = 0.965, yanchor = 'bottom')}
14111408
],
1412-
label = str(round(threshold, n_of_decimals))
1409+
label = str(threshold)
14131410
)
14141411

14151412
step["args"][0]["visible"][0] = True # constant metric table always visible
1416-
step["args"][0]["visible"][1] = True # opt. thresholds/empty table always visible
14171413
step["args"][0]["visible"][j] = True # threshold related confusion matrix
14181414
step["args"][0]["visible"][j+1] = True # threshold related variable metrics table
1415+
step["args"][0]["visible"][-1] = True # opt. thresholds/empty table always visible
14191416
steps.append(step)
14201417
j += 2 # add 2 to trace index (confusion matrix and variable metrics table)
14211418

@@ -1496,7 +1493,7 @@ def confusion_linechart_plot(true_y, predicted_proba, threshold_step = 0.01,
14961493
except:
14971494
n_of_decimals = 4
14981495

1499-
threshold_values = list(np.arange(0, 1 + threshold_step, threshold_step))
1496+
threshold_values = list(np.round(np.arange(0, 1 + threshold_step, threshold_step), n_of_decimals)) #define thresholds list
15001497
middle_x = (threshold_values[0] + threshold_values[-1])/2
15011498
n_data = len(true_y)
15021499
main_title = f"<b>{title}</b><br>"
@@ -1593,7 +1590,7 @@ def confusion_linechart_plot(true_y, predicted_proba, threshold_step = 0.01,
15931590

15941591
if x_intersect:
15951592
intercepts_str = 'Swaps: '
1596-
intercepts_str += ", ".join(str(round(x, n_of_decimals)) for x in x_intersect)
1593+
intercepts_str += ", ".join(str(x) for x in x_intersect)
15971594
fig.add_annotation(xref="x domain",yref="y domain",x=0.5, y=1.15, showarrow=False,
15981595
text=intercepts_str, row=row_index, col=col_index)
15991596

@@ -1731,7 +1728,7 @@ def confusion_linechart_plot(true_y, predicted_proba, threshold_step = 0.01,
17311728
textposition = textposition,
17321729
hoverlabel = dict(bgcolor = 'rgb(68, 68, 68)'),
17331730
hovertemplate = '%{x:.' + str(n_of_decimals) + 'f}<extra></extra>',
1734-
name = str(round(threshold, n_of_decimals)),
1731+
name = str(threshold),
17351732
marker=dict(color=color),
17361733
marker_size = 8,
17371734
visible=False),
@@ -1758,7 +1755,7 @@ def confusion_linechart_plot(true_y, predicted_proba, threshold_step = 0.01,
17581755
+ subtitle + titles[threshold] + '</span>',
17591756
y = 0.965, yanchor = 'bottom')}
17601757
],
1761-
label = str(round(threshold, n_of_decimals))
1758+
label = str(threshold)
17621759
)
17631760
step["args"][0]["visible"][:static_charts_num] = [True]*static_charts_num # line charts
17641761
step["args"][0]["visible"][j:j+markers_num] = [True]*markers_num # line chart markers
@@ -1860,7 +1857,7 @@ def total_amount_cost_plot(true_y, predicted_proba, threshold_step = 0.01,
18601857
except:
18611858
n_of_decimals = 4
18621859

1863-
threshold_values = list(np.arange(0, 1 + threshold_step, threshold_step))
1860+
threshold_values = list(np.round(np.arange(0, 1 + threshold_step, threshold_step), n_of_decimals)) #define thresholds list
18641861
middle_x = (threshold_values[0] + threshold_values[-1])/2
18651862

18661863
supported_label = ["TN", "FP", "FN", "TP"]
@@ -1955,7 +1952,7 @@ def total_amount_cost_plot(true_y, predicted_proba, threshold_step = 0.01,
19551952

19561953
if x_intersect:
19571954
intercepts_str = 'Swaps at thresholds: '
1958-
intercepts_str += ", ".join(str(round(x, n_of_decimals)) for x in x_intersect)
1955+
intercepts_str += ", ".join(str(x) for x in x_intersect)
19591956

19601957
fig.update_layout(title = dict(text = f"<b>{title}</b><span style='font-size: 13px;'><br>" + subtitle + \
19611958
'<br>' + intercepts_str,

0 commit comments

Comments
 (0)