Skip to content

Commit

Permalink
warn when metric contains NaN (#326)
Browse files Browse the repository at this point in the history
  • Loading branch information
santiviquez authored Oct 3, 2023
1 parent 7dd53d3 commit dd20ef7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
12 changes: 12 additions & 0 deletions nannyml/performance_calculation/metrics/binary_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, roc_auc_score
import warnings

from nannyml._typing import ProblemType
from nannyml.base import _list_missing
Expand Down Expand Up @@ -99,6 +100,7 @@ def _calculate(self, data: pd.DataFrame):
y_true, y_pred = _common_data_cleaning(y_true, y_pred)

if y_true.nunique() <= 1:
warnings.warn("Calculated ROC-AUC score contains NaN values.")
return np.nan
else:
return roc_auc_score(y_true, y_pred)
Expand Down Expand Up @@ -167,6 +169,7 @@ def _calculate(self, data: pd.DataFrame):
y_true, y_pred = _common_data_cleaning(y_true, y_pred)

if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
warnings.warn("Calculated F1-score contains NaN values.")
return np.nan
else:
return f1_score(y_true, y_pred)
Expand Down Expand Up @@ -234,6 +237,7 @@ def _calculate(self, data: pd.DataFrame):
y_true, y_pred = _common_data_cleaning(y_true, y_pred)

if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
warnings.warn("Calculated Precision score contains NaN values.")
return np.nan
else:
return precision_score(y_true, y_pred)
Expand Down Expand Up @@ -301,6 +305,7 @@ def _calculate(self, data: pd.DataFrame):
y_true, y_pred = _common_data_cleaning(y_true, y_pred)

if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
warnings.warn("Calculated Recall score contains NaN values.")
return np.nan
else:
return recall_score(y_true, y_pred)
Expand Down Expand Up @@ -373,6 +378,7 @@ def _calculate(self, data: pd.DataFrame):
y_true, y_pred = _common_data_cleaning(y_true, y_pred)

if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
warnings.warn("Calculated Specificity score contains NaN values.")
return np.nan
else:
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
Expand Down Expand Up @@ -446,6 +452,7 @@ def _calculate(self, data: pd.DataFrame):
y_true, y_pred = _common_data_cleaning(y_true, y_pred)

if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
warnings.warn("Calculated Accuracy score contains NaN values.")
return np.nan
else:
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
Expand Down Expand Up @@ -564,6 +571,7 @@ def _calculate(self, data: pd.DataFrame):
business_value = num_tp * tp_value + num_tn * tn_value + num_fp * fp_value + num_fn * fn_value

if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
warnings.warn("Calculated Business Value contains NaN values.")
return np.nan
else:
if self.normalize_business_value is None:
Expand Down Expand Up @@ -745,6 +753,7 @@ def _calculate_true_positives(self, data: pd.DataFrame) -> float:

y_true, y_pred = _common_data_cleaning(y_true, y_pred)
if y_true.empty or y_pred.empty:
warnings.warn("Calculated true_positives contain NaN values.")
return np.nan

num_tp = np.sum(np.logical_and(y_pred, y_true))
Expand Down Expand Up @@ -773,6 +782,7 @@ def _calculate_true_negatives(self, data: pd.DataFrame) -> float:

y_true, y_pred = _common_data_cleaning(y_true, y_pred)
if y_true.empty or y_pred.empty:
warnings.warn("Calculated true_negatives contain NaN values.")
return np.nan

num_tn = np.sum(np.logical_and(np.logical_not(y_pred), np.logical_not(y_true)))
Expand Down Expand Up @@ -801,6 +811,7 @@ def _calculate_false_positives(self, data: pd.DataFrame) -> float:

y_true, y_pred = _common_data_cleaning(y_true, y_pred)
if y_true.empty or y_pred.empty:
warnings.warn("Calculated false_positives contain NaN values.")
return np.nan

num_fp = np.sum(np.logical_and(y_pred, np.logical_not(y_true)))
Expand Down Expand Up @@ -829,6 +840,7 @@ def _calculate_false_negatives(self, data: pd.DataFrame) -> float:

y_true, y_pred = _common_data_cleaning(y_true, y_pred)
if y_true.empty or y_pred.empty:
warnings.warn("Calculated false_negatives contain NaN values.")
return np.nan

num_fn = np.sum(np.logical_and(np.logical_not(y_pred), y_true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import numpy as np
import pandas as pd
import warnings
from sklearn.metrics import (
accuracy_score,
f1_score,
Expand Down Expand Up @@ -127,6 +128,7 @@ def _calculate(self, data: pd.DataFrame):

y_true, y_pred = _common_data_cleaning(y_true, y_pred)
if y_true.nunique() <= 1:
warnings.warn("Calculated ROC-AUC score contains NaN values.")
return np.nan
else:
return roc_auc_score(y_true, y_pred, multi_class='ovr', average='macro', labels=labels)
Expand Down Expand Up @@ -214,6 +216,7 @@ def _calculate(self, data: pd.DataFrame):

y_true, y_pred = _common_data_cleaning(y_true, y_pred)
if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
warnings.warn("Calculated F1-score contains NaN values.")
return np.nan
else:
return f1_score(y_true, y_pred, average='macro', labels=labels)
Expand Down Expand Up @@ -301,6 +304,7 @@ def _calculate(self, data: pd.DataFrame):

y_true, y_pred = _common_data_cleaning(y_true, y_pred)
if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
warnings.warn("Calculated Precision score contains NaN values.")
return np.nan
else:
return precision_score(y_true, y_pred, average='macro', labels=labels)
Expand Down Expand Up @@ -388,6 +392,7 @@ def _calculate(self, data: pd.DataFrame):

y_true, y_pred = _common_data_cleaning(y_true, y_pred)
if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
warnings.warn("Calculated Recall score contains NaN values.")
return np.nan
else:
return recall_score(y_true, y_pred, average='macro', labels=labels)
Expand Down Expand Up @@ -475,6 +480,7 @@ def _calculate(self, data: pd.DataFrame):

y_true, y_pred = _common_data_cleaning(y_true, y_pred)
if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
warnings.warn("Calculated Specificity score contains NaN values.")
return np.nan
else:
MCM = multilabel_confusion_matrix(y_true, y_pred, labels=labels)
Expand Down Expand Up @@ -558,6 +564,7 @@ def _calculate(self, data: pd.DataFrame):

y_true, y_pred = _common_data_cleaning(y_true, y_pred)
if (y_true.nunique() <= 1) or (y_pred.nunique() <= 1):
warnings.warn("Calculated Accuracy score contains NaN values.")
return np.nan
else:
return accuracy_score(y_true, y_pred)
Expand Down

0 comments on commit dd20ef7

Please sign in to comment.