Skip to content

Commit 7be4853

Browse files
committed
changed colours of plots
changed colours of plots
1 parent 291db62 commit 7be4853

7 files changed

+54205
-54272
lines changed

bctools/plots.py

+25-21
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from .thresholds import get_optimized_thresholds_df
1717

18-
def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01, title = "Interactive Probabilities Violin Plot"):
18+
def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01, marker_size = 3, title = "Interactive Probabilities Violin Plot"):
1919

2020
"""
2121
Plots interactive and customized violin plots of predicted probabilties with plotly,
@@ -37,6 +37,8 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
3737
threshold_step: float, default=0.01
3838
step between each classification threshold (ranging from 0 to 1) below which prediction label is 0, 1 otherwise
3939
each value will have a corresponding slider step
40+
marker_size: int, default=3
41+
Size of the points to be plotted
4042
title: str, default='Interactive probabilities Violin Plot'
4143
The main title of the plot.
4244
"""
@@ -55,7 +57,7 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
5557
main_title = f"<b>{title}</b><br>"
5658

5759
# VIOLIN PLOT
58-
full_fig=go.Figure(data=go.Violin(y=data_df['pred'], x=data_df['class'], line_color='black',
60+
full_fig=go.Figure(data=go.Violin(y=data_df['pred'], x=data_df['class'], line_color='#0D2A63',
5961
meanline_visible=True, points=False, fillcolor=None, opacity=0.3, box=None,
6062
scalemode='count', showlegend = False))
6163

@@ -86,24 +88,26 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
8688
titles[threshold] = titles[threshold][:-3] #removes last 3 char (2 spaces and comma)
8789

8890
# NOTE: px strip generates n plots, one for each color class (TN, FP, FN, TP) it finds)
89-
strip_points_fig = px.strip(data_df, x='class', y='pred', color= threshold_string,
90-
color_discrete_map = {'FN':'red', 'FP':'mediumpurple',
91-
'TP':'green', 'TN':'blue'},
91+
strip_points_fig = px.strip(data_df, x='class', y='pred', color=threshold_string,
92+
color_discrete_map = {'FN':'#EF71D9', 'FP':'#EF553B',
93+
'TP':'#00CC96', 'TN':'#636EFA'},
9294
log_y=True, width=550, height=550, hover_data = [data_df.index])
9395

94-
strip_points_fig.update_traces(hovertemplate = 'Idx = %{customdata}<br>Class = %{x}<br>Pred = %{y}', jitter = 1, marker_size=3)
96+
strip_points_fig.update_traces(hovertemplate = 'Idx = %{customdata}<br>Class = %{x}<br>Pred = %{y}', jitter = 1, marker_size=marker_size)
9597

9698
length_fig_list.append(len(strip_points_fig.data))
9799

98100
for i in range(len(strip_points_fig.data)):
99101
strip_points_fig.data[i].visible=False
100-
full_fig.add_trace(strip_points_fig.data[i])
102+
103+
full_fig.add_traces(list(strip_points_fig.select_traces()))
101104

102105
full_fig.update_layout(legend_font_size=9.5, legend_itemsizing='constant', legend_traceorder='grouped',
103106
title=dict(text = main_title + '<span style="font-size: 13px;">' \
104107
+ titles[threshold_values[0]] + '</span>',
105108
y = 0.965, yanchor = 'bottom'),
106109
width=550, height=550)
110+
107111
full_fig.update_layout(margin=dict(l=40, r=40, t=60, b=40))
108112

109113
# makes visible the first strip points figure
@@ -201,11 +205,10 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C
201205
"Recall":recall.tolist(),
202206
"Precision":precision.tolist()})
203207

204-
pr_fig=px.line(curve_df, x="Recall", y="Precision", hover_data=["Thresholds"], title=main_title)
208+
full_fig=px.line(curve_df, x="Recall", y="Precision", hover_data=["Thresholds"], title=main_title)
205209

206-
pr_fig.update_traces(hovertemplate='Threshold: %{customdata:.4f} <br>Precision: %{y:.4f} <br>Recall: %{x:.4f}<extra></extra>')
207-
pr_fig.update_traces(line_color='#222A2A', line_width=2, textposition="top center")
208-
full_fig = pr_fig
210+
full_fig.update_traces(hovertemplate='Threshold: %{customdata:.4f} <br>Precision: %{y:.4f} <br>Recall: %{x:.4f}<extra></extra>')
211+
full_fig.update_traces(textposition="top center")
209212

210213
f_scores = np.linspace(0.2, 0.8, num=4)
211214

@@ -222,18 +225,19 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C
222225

223226
iso_fig=px.line(recall_precision_df, x="recall", y="precision")
224227
iso_fig.update_traces(hovertemplate=[]) # no hover info displayed but keeps dashed lines
225-
iso_fig.update_traces(line_color='#778AAE', line=dict(dash='dot'), line_width=0.3)
228+
iso_fig.update_traces(line_color='#4C78A8', line=dict(dash='dot'), line_width=0.8)
226229

227230
full_fig.add_annotation(x=0.90, y=y[45] + 0.01, text="f"+ str(beta) + "={0:0.1f}".format(f_score),
228231
showarrow=False,yshift=10)
229-
full_fig=go.Figure(data = full_fig.data + iso_fig.data, layout = full_fig.layout)
232+
233+
full_fig.add_traces(list(iso_fig.select_traces()))
230234

231235
area_under_pr_curve = auc(recall, precision)
232236

233237
full_fig.update_xaxes(range=[0.0, 1.0],title_text='Recall')
234238
full_fig.update_yaxes(range=[0.0, 1.05],title_text='Precision')
235239

236-
full_fig.add_shape(type='line', line=dict(dash='dash'),x0=0, x1=1, y0=baseline, y1=baseline)
240+
full_fig.add_shape(type='line', line=dict(dash='dash', color = '#20313e'),x0=0, x1=1, y0=baseline, y1=baseline)
237241

238242
full_fig['data'][0]['showlegend']= True
239243
full_fig['data'][1]['showlegend']= True
@@ -244,8 +248,8 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C
244248
legend_font_size=9.5,
245249
width=550, height=550)
246250

247-
full_fig.update_xaxes(showspikes=True)
248-
full_fig.update_yaxes(showspikes=True)
251+
full_fig.update_xaxes(showspikes=True, spikedash = 'dot', spikethickness=2)
252+
full_fig.update_yaxes(showspikes=True, spikedash = 'dot', spikethickness=2)
249253
full_fig.update_layout(margin=dict(l=40, r=40, t=40, b=40))
250254
full_fig.show()
251255

@@ -292,22 +296,22 @@ def curve_ROC_plot(true_y, predicted_proba, title = "Receiver Operating Characte
292296
hover_data=["Thresholds"],
293297
width=550, height=550)
294298

295-
fig.update_traces(line_color="#222A2A", line_width=2, textposition="top center")
299+
fig.update_traces(textposition="top center")
296300
fig.update_traces(hovertemplate='Threshold: %{customdata:.4f} <br>False Positive Rate: %{x:.4f} <br>True Positive Rate: %{y:.4f}<extra></extra>')
297301

298-
fig.add_shape(type="line", line=dict(dash="dash"),
302+
fig.add_shape(type="line", line=dict(dash="dash", color = '#20313e'),
299303
x0=0, x1=1, y0=0, y1=1)
300304

301305
area_under_ROC_curve = auc(fpr, tpr)
302306

303307
fig["data"][0]["name"]= f"ROC Curve (AUC={area_under_ROC_curve:.3f})"
304-
305308
fig["data"][0]["showlegend"]= True
309+
306310
fig.update_layout(legend = dict(yanchor="top", y=0.20, xanchor="left", x=0.5),
307311
legend_font_size=9.5)
308312

309-
fig.update_xaxes(showspikes=True)
310-
fig.update_yaxes(showspikes=True)
313+
fig.update_xaxes(showspikes=True, spikedash = 'dot', spikethickness=2)
314+
fig.update_yaxes(showspikes=True, spikedash = 'dot', spikethickness=2)
311315

312316
fig.update_yaxes(scaleanchor="x", scaleratio=1)
313317
fig.update_xaxes(range=[0,1], constrain="domain")

example-notebook/example_classification_model.ipynb

-54,250
This file was deleted.

0 commit comments

Comments
 (0)