15
15
16
16
from .thresholds import get_optimized_thresholds_df
17
17
18
- def predicted_proba_violin_plot (true_y , predicted_proba , threshold_step = 0.01 , title = "Interactive Probabilities Violin Plot" ):
18
+ def predicted_proba_violin_plot (true_y , predicted_proba , threshold_step = 0.01 , marker_size = 3 , title = "Interactive Probabilities Violin Plot" ):
19
19
20
20
"""
21
21
Plots interactive and customized violin plots of predicted probabilties with plotly,
@@ -37,6 +37,8 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
37
37
threshold_step: float, default=0.01
38
38
step between each classification threshold (ranging from 0 to 1) below which prediction label is 0, 1 otherwise
39
39
each value will have a corresponding slider step
40
+ marker_size: int, default=3
41
+ Size of the points to be plotted
40
42
title: str, default='Interactive probabilities Violin Plot'
41
43
The main title of the plot.
42
44
"""
@@ -55,7 +57,7 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
55
57
main_title = f"<b>{ title } </b><br>"
56
58
57
59
# VIOLIN PLOT
58
- full_fig = go .Figure (data = go .Violin (y = data_df ['pred' ], x = data_df ['class' ], line_color = 'black ' ,
60
+ full_fig = go .Figure (data = go .Violin (y = data_df ['pred' ], x = data_df ['class' ], line_color = '#0D2A63 ' ,
59
61
meanline_visible = True , points = False , fillcolor = None , opacity = 0.3 , box = None ,
60
62
scalemode = 'count' , showlegend = False ))
61
63
@@ -86,24 +88,26 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
86
88
titles [threshold ] = titles [threshold ][:- 3 ] #removes last 3 char (2 spaces and comma)
87
89
88
90
# NOTE: px strip generates n plots, one for each color class (TN, FP, FN, TP) it finds)
89
- strip_points_fig = px .strip (data_df , x = 'class' , y = 'pred' , color = threshold_string ,
90
- color_discrete_map = {'FN' :'red ' , 'FP' :'mediumpurple ' ,
91
- 'TP' :'green ' , 'TN' :'blue ' },
91
+ strip_points_fig = px .strip (data_df , x = 'class' , y = 'pred' , color = threshold_string ,
92
+ color_discrete_map = {'FN' :'#EF71D9 ' , 'FP' :'#EF553B ' ,
93
+ 'TP' :'#00CC96 ' , 'TN' :'#636EFA ' },
92
94
log_y = True , width = 550 , height = 550 , hover_data = [data_df .index ])
93
95
94
- strip_points_fig .update_traces (hovertemplate = 'Idx = %{customdata}<br>Class = %{x}<br>Pred = %{y}' , jitter = 1 , marker_size = 3 )
96
+ strip_points_fig .update_traces (hovertemplate = 'Idx = %{customdata}<br>Class = %{x}<br>Pred = %{y}' , jitter = 1 , marker_size = marker_size )
95
97
96
98
length_fig_list .append (len (strip_points_fig .data ))
97
99
98
100
for i in range (len (strip_points_fig .data )):
99
101
strip_points_fig .data [i ].visible = False
100
- full_fig .add_trace (strip_points_fig .data [i ])
102
+
103
+ full_fig .add_traces (list (strip_points_fig .select_traces ()))
101
104
102
105
full_fig .update_layout (legend_font_size = 9.5 , legend_itemsizing = 'constant' , legend_traceorder = 'grouped' ,
103
106
title = dict (text = main_title + '<span style="font-size: 13px;">' \
104
107
+ titles [threshold_values [0 ]] + '</span>' ,
105
108
y = 0.965 , yanchor = 'bottom' ),
106
109
width = 550 , height = 550 )
110
+
107
111
full_fig .update_layout (margin = dict (l = 40 , r = 40 , t = 60 , b = 40 ))
108
112
109
113
# makes visible the first strip points figure
@@ -201,11 +205,10 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C
201
205
"Recall" :recall .tolist (),
202
206
"Precision" :precision .tolist ()})
203
207
204
- pr_fig = px .line (curve_df , x = "Recall" , y = "Precision" , hover_data = ["Thresholds" ], title = main_title )
208
+ full_fig = px .line (curve_df , x = "Recall" , y = "Precision" , hover_data = ["Thresholds" ], title = main_title )
205
209
206
- pr_fig .update_traces (hovertemplate = 'Threshold: %{customdata:.4f} <br>Precision: %{y:.4f} <br>Recall: %{x:.4f}<extra></extra>' )
207
- pr_fig .update_traces (line_color = '#222A2A' , line_width = 2 , textposition = "top center" )
208
- full_fig = pr_fig
210
+ full_fig .update_traces (hovertemplate = 'Threshold: %{customdata:.4f} <br>Precision: %{y:.4f} <br>Recall: %{x:.4f}<extra></extra>' )
211
+ full_fig .update_traces (textposition = "top center" )
209
212
210
213
f_scores = np .linspace (0.2 , 0.8 , num = 4 )
211
214
@@ -222,18 +225,19 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C
222
225
223
226
iso_fig = px .line (recall_precision_df , x = "recall" , y = "precision" )
224
227
iso_fig .update_traces (hovertemplate = []) # no hover info displayed but keeps dashed lines
225
- iso_fig .update_traces (line_color = '#778AAE ' , line = dict (dash = 'dot' ), line_width = 0.3 )
228
+ iso_fig .update_traces (line_color = '#4C78A8 ' , line = dict (dash = 'dot' ), line_width = 0.8 )
226
229
227
230
full_fig .add_annotation (x = 0.90 , y = y [45 ] + 0.01 , text = "f" + str (beta ) + "={0:0.1f}" .format (f_score ),
228
231
showarrow = False ,yshift = 10 )
229
- full_fig = go .Figure (data = full_fig .data + iso_fig .data , layout = full_fig .layout )
232
+
233
+ full_fig .add_traces (list (iso_fig .select_traces ()))
230
234
231
235
area_under_pr_curve = auc (recall , precision )
232
236
233
237
full_fig .update_xaxes (range = [0.0 , 1.0 ],title_text = 'Recall' )
234
238
full_fig .update_yaxes (range = [0.0 , 1.05 ],title_text = 'Precision' )
235
239
236
- full_fig .add_shape (type = 'line' , line = dict (dash = 'dash' ),x0 = 0 , x1 = 1 , y0 = baseline , y1 = baseline )
240
+ full_fig .add_shape (type = 'line' , line = dict (dash = 'dash' , color = '#20313e' ),x0 = 0 , x1 = 1 , y0 = baseline , y1 = baseline )
237
241
238
242
full_fig ['data' ][0 ]['showlegend' ]= True
239
243
full_fig ['data' ][1 ]['showlegend' ]= True
@@ -244,8 +248,8 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C
244
248
legend_font_size = 9.5 ,
245
249
width = 550 , height = 550 )
246
250
247
- full_fig .update_xaxes (showspikes = True )
248
- full_fig .update_yaxes (showspikes = True )
251
+ full_fig .update_xaxes (showspikes = True , spikedash = 'dot' , spikethickness = 2 )
252
+ full_fig .update_yaxes (showspikes = True , spikedash = 'dot' , spikethickness = 2 )
249
253
full_fig .update_layout (margin = dict (l = 40 , r = 40 , t = 40 , b = 40 ))
250
254
full_fig .show ()
251
255
@@ -292,22 +296,22 @@ def curve_ROC_plot(true_y, predicted_proba, title = "Receiver Operating Characte
292
296
hover_data = ["Thresholds" ],
293
297
width = 550 , height = 550 )
294
298
295
- fig .update_traces (line_color = "#222A2A" , line_width = 2 , textposition = "top center" )
299
+ fig .update_traces (textposition = "top center" )
296
300
fig .update_traces (hovertemplate = 'Threshold: %{customdata:.4f} <br>False Positive Rate: %{x:.4f} <br>True Positive Rate: %{y:.4f}<extra></extra>' )
297
301
298
- fig .add_shape (type = "line" , line = dict (dash = "dash" ),
302
+ fig .add_shape (type = "line" , line = dict (dash = "dash" , color = '#20313e' ),
299
303
x0 = 0 , x1 = 1 , y0 = 0 , y1 = 1 )
300
304
301
305
area_under_ROC_curve = auc (fpr , tpr )
302
306
303
307
fig ["data" ][0 ]["name" ]= f"ROC Curve (AUC={ area_under_ROC_curve :.3f} )"
304
-
305
308
fig ["data" ][0 ]["showlegend" ]= True
309
+
306
310
fig .update_layout (legend = dict (yanchor = "top" , y = 0.20 , xanchor = "left" , x = 0.5 ),
307
311
legend_font_size = 9.5 )
308
312
309
- fig .update_xaxes (showspikes = True )
310
- fig .update_yaxes (showspikes = True )
313
+ fig .update_xaxes (showspikes = True , spikedash = 'dot' , spikethickness = 2 )
314
+ fig .update_yaxes (showspikes = True , spikedash = 'dot' , spikethickness = 2 )
311
315
312
316
fig .update_yaxes (scaleanchor = "x" , scaleratio = 1 )
313
317
fig .update_xaxes (range = [0 ,1 ], constrain = "domain" )
0 commit comments