diff --git a/nannyml/drift/multivariate/domain_classifier/calculator.py b/nannyml/drift/multivariate/domain_classifier/calculator.py
index 39dbbc01..dbd04ebb 100644
--- a/nannyml/drift/multivariate/domain_classifier/calculator.py
+++ b/nannyml/drift/multivariate/domain_classifier/calculator.py
@@ -15,7 +15,7 @@
 """
 import warnings

-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union

 import numpy as np
 import pandas as pd
@@ -27,7 +27,7 @@ from sklearn.preprocessing import OrdinalEncoder

 from nannyml.base import AbstractCalculator, _list_missing, _split_features_by_type
-from nannyml.chunk import Chunker
+from nannyml.chunk import Chunker, Chunk
 from nannyml.drift.multivariate.domain_classifier.result import Result
 from nannyml.exceptions import InvalidArgumentsException
@@ -200,6 +200,7 @@ def __init__(
         # # sampling error
         # self._sampling_error_components: Tuple = ()
         self.result: Optional[Result] = None
+        self._is_fitted: bool = False

     @log_usage(UsageEvent.DC_CALC_FIT)
     def _fit(self, reference_data: pd.DataFrame, *args, **kwargs):
@@ -225,11 +226,28 @@ def _fit(self, reference_data: pd.DataFrame, *args, **kwargs):
             if column_name not in self.categorical_column_names:
                 self.categorical_column_names.append(column_name)

-        self._reference_X = reference_data[self.feature_column_names]
+        # Get the timestamp column from the chunker in case the calculator was initialized with a chunker
+        # without directly being provided the timestamp column name.
+        #
+        # The reference data will be sorted according to the timestamp column (when available) to mimic
+        # Chunker behavior. This means the reference data will be "aligned" with chunked reference data.
+        # This way we can use chunk indices on the internal reference data copy.
+        if self.chunker.timestamp_column_name:
+            if self.chunker.timestamp_column_name not in list(reference_data.columns):
+                raise InvalidArgumentsException(
+                    f"timestamp column '{self.chunker.timestamp_column_name}' not in columns: {list(reference_data.columns)}."  # noqa: E501
+                )
+            self._reference_X = reference_data.sort_values(
+                by=[self.chunker.timestamp_column_name]
+            ).reset_index(drop=True)[self.feature_column_names]
+        else:
+            self._reference_X = reference_data[self.feature_column_names]

         self.result = self._calculate(data=reference_data)
         self.result.data[('chunk', 'period')] = 'reference'
+        self._is_fitted = True
+
+        return self

     @log_usage(UsageEvent.DC_CALC_RUN)
@@ -252,7 +270,7 @@ def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> Result:
                 'end_date': chunk.end_datetime,
                 'period': 'analysis',
                 # 'sampling_error': sampling_error(self._sampling_error_components, chunk.data),
-                'classifier_auroc_value': self._calculate_chunk(data=chunk.data),
+                'classifier_auroc_value': self._calculate_chunk(chunk=chunk),
             }
             for chunk in chunks
         ]
@@ -262,7 +280,7 @@ def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> Result:
         res.columns = multilevel_index
         res = res.reset_index(drop=True)

-        if self.result is None:
+        if not self._is_fitted:
             self._set_metric_thresholds(res)
             res = self._populate_alert_thresholds(res)
             self.result = Result(
@@ -274,20 +292,26 @@ def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> Result:
             )
         else:
             res = self._populate_alert_thresholds(res)
-            self.result = self.result.filter(period='reference')
+            self.result = self.result.filter(period='reference')  # type: ignore
             self.result.data = pd.concat([self.result.data, res], ignore_index=True)

         return self.result

-    def _calculate_chunk(self, data: pd.DataFrame):
-
-        chunk_X = data[self.feature_column_names]
-        reference_X = self._reference_X
-        chunk_y = np.ones(len(chunk_X))
-        reference_y = np.zeros(len(reference_X))
-        X = pd.concat([reference_X, chunk_X], ignore_index=True)
-        y = np.concatenate([reference_y, chunk_y])
-
-        X, y = drop_matching_duplicate_rows(X, y, self.feature_column_names)
+    def _calculate_chunk(self, chunk: Chunk):
+        if self._is_fitted:
+            # Analysis period: label the reference copy 0 and the chunk rows 1, then combine them.
+            chunk_X = chunk.data[self.feature_column_names]
+            reference_X = self._reference_X
+            chunk_y = np.ones(len(chunk_X))
+            reference_y = np.zeros(len(reference_X))
+            X = pd.concat([reference_X, chunk_X], ignore_index=True)
+            y = np.concatenate([reference_y, chunk_y])
+        else:
+            # Use information from chunk indices to identify the reference chunk's location. This is possible
+            # because both the internal reference data copy and the chunk data were sorted by timestamp, so
+            # these indices align. This way we eliminate the need to combine the two data frames and drop
+            # duplicate rows, which is a costly operation.
+            X = self._reference_X
+            y = np.zeros(len(X))
+            y[chunk.start_index : chunk.end_index + 1] = 1

         df_X_transformed = preprocess_categorical_features(
             X, self.continuous_column_names, self.categorical_column_names
@@ -351,6 +375,7 @@ def _populate_alert_thresholds(self, result_data: pd.DataFrame) -> pd.DataFrame:
         return result_data

     def tune_hyperparams(self, X: pd.DataFrame, y: np.ndarray):
+        """Train an LGBM model while also performing hyperparameter tuning."""
         with warnings.catch_warnings():
             # Ignore lightgbm's UserWarning: Using categorical_feature in Dataset.
             # We explicitly use that feature, don't spam the user
@@ -366,18 +391,10 @@ def tune_hyperparams(self, X: pd.DataFrame, y: np.ndarray):
         self.hyperparameters = {**automl.model.estimator.get_params()}


-def drop_matching_duplicate_rows(X: pd.DataFrame, y: np.ndarray, subset: List[str]) -> Tuple[pd.DataFrame, np.ndarray]:
-    X['__target__'] = y
-    X = X.drop_duplicates(subset=subset, keep='last').reset_index(drop=True)
-    y = X['__target__']
-    X.drop('__target__', axis=1, inplace=True)
-
-    return X, y
-
-
 def preprocess_categorical_features(
     X: pd.DataFrame, continuous_column_names: List[str], categorical_column_names: List[str]
 ) -> pd.DataFrame:
+    """Preprocess categorical features."""
     X_cont = X[continuous_column_names]

     enc = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)
diff --git a/tests/drift/test_multiv_dc.py b/tests/drift/test_multiv_dc.py
index 42bb0e57..b9c59084 100644
--- a/tests/drift/test_multiv_dc.py
+++ b/tests/drift/test_multiv_dc.py
@@ -49,3 +49,26 @@ def test_default_cdd_run(binary_classification_data):
         0.9136,
     ]
     assert list(results.to_df().loc[:, ("domain_classifier_auroc", "alert")]) == [False, False, False, True, True]
+
+
+def test_cdd_run_w_timestamp(binary_classification_data):
+    """Test a DC run with a timestamp column on shuffled reference data."""
+    (
+        reference,
+        analysis,
+    ) = binary_classification_data
+    calc = DomainClassifierCalculator(
+        feature_column_names=column_names1,
+        chunk_size=5_000,
+        timestamp_column_name='timestamp',
+    )
+    calc.fit(reference.sample(frac=1).reset_index(drop=True))
+    results = calc.calculate(analysis)
+    assert list(results.to_df().loc[:, ("domain_classifier_auroc", "value")].round(4)) == [
+        0.5020,
+        0.5002,
+        0.5174,
+        0.9108,
+        0.9136,
+    ]
+    assert list(results.to_df().loc[:, ("domain_classifier_auroc", "alert")]) == [False, False, False, True, True]
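
For reviewers: below is a minimal, self-contained sketch (not part of the diff) of the index-alignment trick that `_calculate_chunk` now relies on during fitting. The data frame, column names, and chunk boundaries are invented for illustration; only the sort-then-label-by-position pattern mirrors the change.

```python
# Hypothetical illustration of the new labelling scheme; nothing here is NannyML API.
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)

# A small "reference" frame with a shuffled timestamp column, as in the new test.
reference = pd.DataFrame(
    {
        "timestamp": pd.date_range("2024-01-01", periods=10, freq="h"),
        "f1": rng.normal(size=10),
    }
).sample(frac=1, random_state=0)

# Mimic _fit: sorting by timestamp makes row positions match the chunker's chunk indices.
reference_sorted = reference.sort_values(by=["timestamp"]).reset_index(drop=True)

# Mimic _calculate_chunk on reference data for a chunk spanning rows 5..9
# (stand-ins for chunk.start_index / chunk.end_index):
start_index, end_index = 5, 9
y = np.zeros(len(reference_sorted))
y[start_index : end_index + 1] = 1

# The domain classifier then trains on (reference_sorted, y) directly, labelling the
# chunk's rows in place instead of concatenating the full reference with the chunk
# data and dropping duplicate rows afterwards, which the removed helper had to do.
print(y)  # [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
```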