diff --git a/nannyml/__init__.py b/nannyml/__init__.py
index ddb2ef7b5..bc295b7f6 100644
--- a/nannyml/__init__.py
+++ b/nannyml/__init__.py
@@ -50,6 +50,7 @@
load_titanic_dataset,
load_us_census_ma_employment_data,
)
+from .distribution import CategoricalDistributionCalculator, ContinuousDistributionCalculator
from .drift import AlertCountRanker, CorrelationRanker, DataReconstructionDriftCalculator, UnivariateDriftCalculator
from .exceptions import ChunkerException, InvalidArgumentsException, MissingMetadataException
from .io import DatabaseWriter, PickleFileWriter, RawFilesWriter
diff --git a/nannyml/data_quality/unseen/calculator.py b/nannyml/data_quality/unseen/calculator.py
index cb1d41ec1..3ae17873c 100644
--- a/nannyml/data_quality/unseen/calculator.py
+++ b/nannyml/data_quality/unseen/calculator.py
@@ -219,6 +219,7 @@ def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> Result:
# Applicable here but to many of the base classes as well (e.g. fitting and calculating)
self.result = self.result.filter(period='reference')
self.result.data = pd.concat([self.result.data, res]).reset_index(drop=True)
+ self.result.data.sort_index(inplace=True)
return self.result
diff --git a/nannyml/distribution/__init__.py b/nannyml/distribution/__init__.py
new file mode 100644
index 000000000..8d3290202
--- /dev/null
+++ b/nannyml/distribution/__init__.py
@@ -0,0 +1,2 @@
+from .categorical import CategoricalDistributionCalculator
+from .continuous import ContinuousDistributionCalculator
diff --git a/nannyml/distribution/categorical/__init__.py b/nannyml/distribution/categorical/__init__.py
new file mode 100644
index 000000000..95ebbcdbc
--- /dev/null
+++ b/nannyml/distribution/categorical/__init__.py
@@ -0,0 +1 @@
+from .calculator import CategoricalDistributionCalculator
diff --git a/nannyml/distribution/categorical/calculator.py b/nannyml/distribution/categorical/calculator.py
new file mode 100644
index 000000000..6725ee843
--- /dev/null
+++ b/nannyml/distribution/categorical/calculator.py
@@ -0,0 +1,140 @@
+from typing import List, Optional, Union
+
+import numpy as np
+import pandas as pd
+from typing_extensions import Self
+
+from nannyml import Chunker
+from nannyml.base import AbstractCalculator, _list_missing
+from nannyml.distribution.categorical.result import Result
+from nannyml.exceptions import InvalidArgumentsException
+
+
+class CategoricalDistributionCalculator(AbstractCalculator):
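+    """Calculates the distribution of categorical columns over time.
+
+    For every configured column the value counts are computed per chunk, limited to the most
+    frequent categories; remaining values are grouped under 'Other' and missing values under a
+    dedicated label.
+
+    A minimal usage sketch, assuming `reference_df` and `analysis_df` are pandas DataFrames
+    containing an (illustrative) 'marital_status' column and a 'timestamp' column:
+
+    >>> calc = CategoricalDistributionCalculator(
+    ...     column_names=['marital_status'],
+    ...     timestamp_column_name='timestamp',
+    ... ).fit(reference_df)
+    >>> result = calc.calculate(analysis_df)
+    >>> result.plot().show()
+    """
+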
+ def __init__(
+ self,
+ column_names: Union[str, List[str]],
+ timestamp_column_name: Optional[str] = None,
+ chunk_size: Optional[int] = None,
+ chunk_number: Optional[int] = None,
+ chunk_period: Optional[str] = None,
+ chunker: Optional[Chunker] = None,
+ ):
+ super().__init__(
+ chunk_size,
+ chunk_number,
+ chunk_period,
+ chunker,
+ timestamp_column_name,
+ )
+
+ self.column_names = column_names if isinstance(column_names, List) else [column_names]
+ self.result: Optional[Result] = None
+ self._was_fitted: bool = False
+
+ def _fit(self, reference_data: pd.DataFrame, *args, **kwargs) -> Self:
+ self.result = self._calculate(reference_data)
+ self._was_fitted = True
+
+ return self
+
+ def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> Result:
+ if data.empty:
+ raise InvalidArgumentsException('data contains no rows. Please provide a valid data set.')
+
+ _list_missing(self.column_names, data)
+
+        result_data = pd.DataFrame()
+
+ chunks = self.chunker.split(data)
+ chunks_data = pd.DataFrame(
+ {
+ 'key': [c.key for c in chunks],
+ 'chunk_index': [c.chunk_index for c in chunks],
+ 'start_datetime': [c.start_datetime for c in chunks],
+ 'end_datetime': [c.end_datetime for c in chunks],
+ 'start_index': [c.start_index for c in chunks],
+ 'end_index': [c.end_index for c in chunks],
+ 'period': ['analysis' if self._was_fitted else 'reference' for _ in chunks],
+ }
+ )
+
+ for column in self.column_names:
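+            # NOTE: the maximum number of categories and the missing value label are currently
+            # hardcoded here rather than exposed as calculator parameters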
+ value_counts = calculate_value_counts(
+ data=data[column],
+ chunker=self.chunker,
+ timestamps=data.get(self.timestamp_column_name, default=None),
+ max_number_of_categories=5,
+ missing_category_label='Missing',
+ column_name=column,
+ )
+ result_data = pd.concat([result_data, pd.merge(chunks_data, value_counts, on='chunk_index')])
+
+ if self.result is None:
+ self.result = Result(result_data, self.column_names, self.timestamp_column_name, self.chunker)
+ else:
+            self.result = self.result.filter(period='reference')
+            self.result.data = pd.concat([self.result.data, result_data]).reset_index(drop=True)
+
+ return self.result
+
+
+def calculate_value_counts(
+ data: Union[np.ndarray, pd.Series],
+ chunker: Chunker,
+    missing_category_label: str,
+    max_number_of_categories: Optional[int],
+ timestamps: Optional[Union[np.ndarray, pd.Series]] = None,
+ column_name: Optional[str] = None,
+) -> pd.DataFrame:
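+    """Counts the occurrences of each value of a categorical column per chunk.
+
+    Values are cast to categories and missing values are labelled with `missing_category_label`.
+    When the number of distinct values exceeds `max_number_of_categories` (plus the missing
+    label), the least frequent values are grouped under 'Other'. Returns a long-format DataFrame
+    with one row per (chunk_index, value) pair, holding absolute counts, per-chunk totals and
+    normalised frequencies.
+    """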
+ if isinstance(data, np.ndarray):
+ if column_name is None:
+            raise InvalidArgumentsException("'column_name' cannot be None when 'data' is of type 'np.ndarray'.")
+ data = pd.Series(data, name=column_name)
+ else:
+ column_name = data.name
+
+ data = data.astype("category")
+ cat_str = [str(value) for value in data.cat.categories.values]
+ data = data.cat.rename_categories(cat_str)
+ data = data.cat.add_categories([missing_category_label, 'Other'])
+ data = data.fillna(missing_category_label)
+
+ if max_number_of_categories:
+ top_categories = data.value_counts().index.tolist()[:max_number_of_categories]
+ if data.nunique() > max_number_of_categories + 1:
+ data.loc[~data.isin(top_categories)] = 'Other'
+
+ data = data.cat.remove_unused_categories()
+
+ categories_ordered = data.value_counts().index.tolist()
+ categorical_data = pd.Categorical(data, categories_ordered)
+
+ # TODO: deal with None timestamps
+    if isinstance(timestamps, pd.Series):
+        timestamps = timestamps.reset_index(drop=True)
+
+ chunks = chunker.split(pd.concat([pd.Series(categorical_data, name=column_name), timestamps], axis=1))
+ data_with_chunk_keys = pd.concat([chunk.data.assign(chunk_index=chunk.chunk_index) for chunk in chunks])
+
+ value_counts_table = (
+ data_with_chunk_keys.groupby(['chunk_index'])[column_name]
+ .value_counts()
+ .to_frame('value_counts')
+ .sort_values(by=['chunk_index', 'value_counts'])
+ .reset_index()
+ .rename(columns={column_name: 'value'})
+ .assign(column_name=column_name)
+ )
+
+    value_counts_table['value_counts_total'] = value_counts_table.groupby('chunk_index')['value_counts'].transform(
+        'sum'
+    )
+ value_counts_table['value_counts_normalised'] = (
+ value_counts_table['value_counts'] / value_counts_table['value_counts_total']
+ )
+
+ return value_counts_table
diff --git a/nannyml/distribution/categorical/result.py b/nannyml/distribution/categorical/result.py
new file mode 100644
index 000000000..1f74aa3e7
--- /dev/null
+++ b/nannyml/distribution/categorical/result.py
@@ -0,0 +1,440 @@
+import copy
+import math
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import pandas as pd
+import plotly.graph_objs as go
+from typing_extensions import Self
+
+from nannyml import Chunker
+from nannyml._typing import Key
+from nannyml.base import AbstractResult
+from nannyml.drift.univariate.result import Result as DriftResult
+from nannyml.exceptions import InvalidArgumentsException
+from nannyml.plots import Colors, Figure, is_time_based_x_axis
+from nannyml.plots.components.stacked_bar_plot import alert as stacked_bar_alert
+from nannyml.plots.components.stacked_bar_plot import stacked_bar
+
+
+class Result(AbstractResult):
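+    """Contains the results of the categorical distribution calculation and provides plotting functionality."""
+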
+ def __init__(
+ self,
+ results_data: pd.DataFrame,
+ column_names: List[str],
+ timestamp_column_name: Optional[str],
+ chunker: Chunker,
+ ):
+ super().__init__(results_data, column_names)
+
+ self.timestamp_column_name = timestamp_column_name
+ self.chunker = chunker
+ self.column_names = column_names
+
+ def to_df(self, multilevel: bool = True) -> pd.DataFrame:
+ return self.data
+
+ def _filter(
+ self,
+ period: str,
+ metrics: Optional[List[str]] = None,
+ column_names: Optional[Union[str, List[str]]] = None,
+ *args,
+ **kwargs,
+ ) -> Self:
+ data = self.data
+ if period != 'all':
+ data = data.loc[data['period'] == period, :]
+ data = data.reset_index(drop=True)
+
+ if isinstance(column_names, str):
+ column_names = [column_names]
+ if column_names:
+ data = data.loc[data['column_name'].isin(column_names), :]
+
+ res = copy.deepcopy(self)
+ res.data = data
+ return res
+
+ @property
+ def chunk_keys(self) -> pd.Series:
+ return self.data['key']
+
+ @property
+ def chunk_start_dates(self) -> pd.Series:
+ return self.data['start_datetime']
+
+ @property
+ def chunk_end_dates(self) -> pd.Series:
+ return self.data['end_datetime']
+
+ @property
+ def chunk_start_indices(self) -> pd.Series:
+ return self.data['start_index']
+
+ @property
+ def chunk_end_indices(self) -> pd.Series:
+ return self.data['end_index']
+
+ @property
+ def chunk_indices(self) -> pd.Series:
+ return self.data['chunk_index']
+
+ @property
+ def chunk_periods(self) -> pd.Series:
+ return self.data['period']
+
+ def value_counts(self, key: Optional[Key] = None, column_name: Optional[str] = None) -> pd.DataFrame:
+ if not key and not column_name:
+ raise InvalidArgumentsException(
+                "cannot retrieve value counts when neither key nor column_name is set. "
+                "Please provide either a key or a column name."
+ )
+
+ if key:
+ (column_name,) = key.properties
+
+ data = self.filter(column_names=[column_name]).data
+ res = data[
+ [
+ 'value',
+ 'key',
+ 'start_datetime',
+ 'end_datetime',
+ 'start_index',
+ 'end_index',
+ 'chunk_index',
+ 'value_counts',
+ 'value_counts_total',
+ 'value_counts_normalised',
+ ]
+ ].rename(
+ columns={'value': column_name, 'key': 'chunk_key', 'chunk_index': 'chunk_indices'},
+ )
+ res[column_name] = res[column_name].astype('category')
+ return res
+
+ def _get_property_for_key(self, key: Key, property_name: str) -> Optional[pd.Series]:
+ (column_name,) = key.properties
+ return (
+ self.data.loc[self.data['column_name'] == column_name, property_name]
+ if property_name in self.data.columns
+ else None
+ )
+
+ def keys(self) -> List[Key]:
+ return [Key(properties=(c,), display_names=(c,)) for c in self.column_names]
+
+ def plot(self, drift_result: Optional[DriftResult] = None, *args, **kwargs) -> go.Figure:
+ """
+        Creates a stacked bar visualization to illustrate categorical distribution changes over time.
+
+ Parameters
+ ----------
+ drift_result: Optional[nannyml.drift.univariate.Result]
+            The result of a univariate drift calculation. When set, it is used to look up alerts that
+            occurred for each column / drift method combination in the drift calculation result. For each
+            of these combinations a distribution plot of the column is rendered, showing the alerts for
+            that drift method. When the `drift_result` parameter is not set, no alerts are rendered on
+            the distribution plots.
+ """
+
+ if drift_result and not isinstance(drift_result, DriftResult):
+ raise InvalidArgumentsException(
+                'currently the drift_result parameter only supports results of the UnivariateDriftCalculator.'
+ )
+
+ if drift_result:
+ self.check_is_compatible_with(drift_result)
+
+ return (
+ _plot_categorical_distribution_with_alerts(self, drift_result)
+ if drift_result
+ else _plot_categorical_distribution(self)
+ )
+
+ def check_is_compatible_with(self, drift_result: DriftResult):
+ # Check if all distribution columns are present in the drift result
+ drift_column_names = set([col for tup in drift_result.keys() for col, _ in tup])
+ distribution_column_names = set(self.column_names)
+
+ missing_columns = distribution_column_names.difference(drift_column_names)
+ if len(missing_columns) > 0:
+ raise InvalidArgumentsException(
+                "cannot render distribution plots with alerts. The following columns are not "
+                f"present in the drift results: {list(missing_columns)}"
+ )
+
+ # Check if both results use the same X-axis
+ drift_result_is_time_based = is_time_based_x_axis(drift_result.chunk_start_dates, drift_result.chunk_end_dates)
+ distr_result_is_time_based = is_time_based_x_axis(self.chunk_start_dates, self.chunk_end_dates)
+
+ if drift_result_is_time_based != distr_result_is_time_based:
+ raise InvalidArgumentsException(
+                "cannot render distribution plots with alerts. Drift results are"
+                f"{'' if drift_result_is_time_based else ' not'} time-based, distribution results are"
+                f"{'' if distr_result_is_time_based else ' not'} time-based. Drift and distribution results "
+                "should either both be time-based (have a timestamp column) or neither."
+ )
+
+
+def _plot_categorical_distribution(
+ result: Result,
+ title: Optional[str] = 'Column distributions',
+ figure: Optional[go.Figure] = None,
+ x_axis_time_title: str = 'Time',
+ x_axis_chunk_title: str = 'Chunk',
+ y_axis_title: str = 'Values',
+ figure_args: Optional[Dict[str, Any]] = None,
+ subplot_title_format: str = '{display_names[0]} distribution',
+ number_of_columns: Optional[int] = None,
+) -> go.Figure:
+ number_of_plots = len(result.keys())
+ if number_of_columns is None:
+ number_of_columns = min(number_of_plots, 1)
+ number_of_rows = math.ceil(number_of_plots / number_of_columns)
+
+ if figure_args is None:
+ figure_args = {}
+
+ if figure is None:
+ figure = Figure(
+ **dict(
+ title=title,
+ x_axis_title=x_axis_time_title
+ if is_time_based_x_axis(result.chunk_start_dates, result.chunk_end_dates)
+ else x_axis_chunk_title,
+ y_axis_title=y_axis_title,
+ legend=dict(traceorder="grouped", itemclick=False, itemdoubleclick=False),
+ height=number_of_plots * 500 / number_of_columns,
+ subplot_args=dict(
+ cols=number_of_columns,
+ rows=number_of_rows,
+ subplot_titles=[
+ subplot_title_format.format(display_names=key.display_names) for key in result.keys()
+ ],
+ ),
+ **figure_args,
+ )
+ )
+
+ for idx, key in enumerate(result.keys()):
+ row = (idx // number_of_columns) + 1
+ col = (idx % number_of_columns) + 1
+
+ (column_name,) = key.properties
+
+ reference_result = result.filter(period='reference', column_names=[column_name])
+ analysis_result = result.filter(period='analysis', column_names=[column_name])
+
+ figure = _plot_stacked_bar(
+ figure=figure,
+ row=row,
+ col=col,
+ column_name=column_name,
+ reference_value_counts=reference_result.value_counts(key),
+ reference_alerts=None,
+ reference_chunk_keys=reference_result.chunk_keys,
+ reference_chunk_periods=reference_result.chunk_periods,
+ reference_chunk_indices=reference_result.chunk_indices,
+ reference_chunk_start_dates=reference_result.chunk_start_dates,
+ reference_chunk_end_dates=reference_result.chunk_end_dates,
+ analysis_value_counts=analysis_result.value_counts(key),
+ analysis_alerts=None,
+ analysis_chunk_keys=analysis_result.chunk_keys,
+ analysis_chunk_periods=analysis_result.chunk_periods,
+ analysis_chunk_indices=analysis_result.chunk_indices,
+ analysis_chunk_start_dates=analysis_result.chunk_start_dates,
+ analysis_chunk_end_dates=analysis_result.chunk_end_dates,
+ )
+
+ return figure
+
+
+def _plot_categorical_distribution_with_alerts(
+ result: Result,
+ drift_result: DriftResult,
+ title: Optional[str] = 'Column distributions',
+ figure: Optional[go.Figure] = None,
+ x_axis_time_title: str = 'Time',
+ x_axis_chunk_title: str = 'Chunk',
+ y_axis_title: str = 'Values',
+ figure_args: Optional[Dict[str, Any]] = None,
+ subplot_title_format: str = '{display_names[0]} distribution (alerts for {display_names[1]})',
+ number_of_columns: Optional[int] = None,
+) -> go.Figure:
+ number_of_plots = len(drift_result.keys())
+ if number_of_columns is None:
+ number_of_columns = min(number_of_plots, 1)
+ number_of_rows = math.ceil(number_of_plots / number_of_columns)
+
+ if figure_args is None:
+ figure_args = {}
+
+ if figure is None:
+ figure = Figure(
+ **dict(
+ title=title,
+ x_axis_title=x_axis_time_title
+ if is_time_based_x_axis(result.chunk_start_dates, result.chunk_end_dates)
+ else x_axis_chunk_title,
+ y_axis_title=y_axis_title,
+ legend=dict(traceorder="grouped", itemclick=False, itemdoubleclick=False),
+ height=number_of_plots * 500 / number_of_columns,
+ subplot_args=dict(
+ cols=number_of_columns,
+ rows=number_of_rows,
+ subplot_titles=[
+ subplot_title_format.format(display_names=key.display_names) for key in drift_result.keys()
+ ],
+ ),
+ **figure_args,
+ )
+ )
+
+ for idx, drift_key in enumerate(drift_result.keys()):
+ row = (idx // number_of_columns) + 1
+ col = (idx % number_of_columns) + 1
+
+ (column_name, method_name) = drift_key.properties
+
+ reference_result = result.filter(period='reference', column_names=[column_name])
+ reference_result.data.sort_index(inplace=True)
+ analysis_result = result.filter(period='analysis', column_names=[column_name])
+ analysis_result.data.sort_index(inplace=True)
+
+        # Reference alerts are not rendered; only analysis alerts are shown on the plot
+ analysis_alerts = drift_result.filter(period='analysis').alerts(drift_key)
+
+ figure = _plot_stacked_bar(
+ figure=figure,
+ row=row,
+ col=col,
+ column_name=column_name,
+ reference_value_counts=reference_result.value_counts(column_name=column_name),
+ reference_alerts=None,
+ reference_chunk_keys=reference_result.chunk_keys,
+ reference_chunk_periods=reference_result.chunk_periods,
+ reference_chunk_indices=reference_result.chunk_indices,
+ reference_chunk_start_dates=reference_result.chunk_start_dates,
+ reference_chunk_end_dates=reference_result.chunk_end_dates,
+ analysis_value_counts=analysis_result.value_counts(column_name=column_name),
+ analysis_alerts=analysis_alerts,
+ analysis_chunk_keys=analysis_result.chunk_keys,
+ analysis_chunk_periods=analysis_result.chunk_periods,
+ analysis_chunk_indices=analysis_result.chunk_indices,
+ analysis_chunk_start_dates=analysis_result.chunk_start_dates,
+ analysis_chunk_end_dates=analysis_result.chunk_end_dates,
+ )
+
+ return figure
+
+
+def _plot_stacked_bar(
+ figure: Figure,
+ column_name: str,
+ reference_value_counts: pd.DataFrame,
+ analysis_value_counts: pd.DataFrame,
+ reference_alerts: Optional[Union[np.ndarray, pd.Series]] = None,
+ reference_chunk_keys: Optional[Union[np.ndarray, pd.Series]] = None,
+ reference_chunk_periods: Optional[Union[np.ndarray, pd.Series]] = None,
+ reference_chunk_indices: Optional[Union[np.ndarray, pd.Series]] = None,
+ reference_chunk_start_dates: Optional[Union[np.ndarray, pd.Series]] = None,
+ reference_chunk_end_dates: Optional[Union[np.ndarray, pd.Series]] = None,
+ analysis_alerts: Optional[Union[np.ndarray, pd.Series]] = None,
+ analysis_chunk_keys: Optional[Union[np.ndarray, pd.Series]] = None,
+ analysis_chunk_periods: Optional[Union[np.ndarray, pd.Series]] = None,
+ analysis_chunk_indices: Optional[Union[np.ndarray, pd.Series]] = None,
+ analysis_chunk_start_dates: Optional[Union[np.ndarray, pd.Series]] = None,
+ analysis_chunk_end_dates: Optional[Union[np.ndarray, pd.Series]] = None,
+ row: Optional[int] = None,
+ col: Optional[int] = None,
+) -> Figure:
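+    """Renders a stacked bar chart for a single categorical column, as a subplot when `row` and `col` are set.
+
+    Reference and analysis value counts share a single x-axis: analysis chunk indices are shifted
+    to continue where the reference chunks end. When `analysis_alerts` is given, alerting chunks
+    are highlighted on the analysis bars.
+    """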
+ is_subplot = row is not None and col is not None
+ subplot_args = dict(row=row, col=col) if is_subplot else None
+
+ has_reference_results = reference_chunk_indices is not None and len(reference_chunk_indices) > 0
+
+ if figure is None:
+        figure = Figure(title='categorical distribution', x_axis_title='time', y_axis_title='value', height=500)
+
+ figure.update_xaxes(
+ dict(mirror=False, showline=False),
+ overwrite=True,
+ matches='x',
+ title=figure.layout.xaxis.title,
+ row=row,
+ col=col,
+ )
+ figure.update_yaxes(
+ dict(mirror=False, showline=False), overwrite=True, title=figure.layout.yaxis.title, row=row, col=col
+ )
+
+ if has_reference_results:
+ figure = stacked_bar(
+ figure=figure,
+ stacked_bar_table=reference_value_counts,
+ color=Colors.BLUE_SKY_CRAYOLA,
+ chunk_indices=reference_chunk_indices,
+ chunk_start_dates=reference_chunk_start_dates,
+ chunk_end_dates=reference_chunk_end_dates,
+ annotation='Reference',
+ showlegend=True,
+ legendgrouptitle_text=f'{column_name}',
+ legendgroup=column_name,
+ subplot_args=subplot_args,
+ )
+
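+        # Shift analysis chunk indices so the analysis period continues where the reference period ends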
+ assert reference_chunk_indices is not None
+ analysis_chunk_indices = (analysis_chunk_indices + (max(reference_chunk_indices) + 1)).reset_index(drop=True)
+ analysis_value_counts['chunk_indices'] += max(reference_chunk_indices) + 1
+
+ if analysis_chunk_start_dates is not None:
+ analysis_chunk_start_dates = analysis_chunk_start_dates.reset_index(drop=True)
+
+ figure = stacked_bar(
+ figure=figure,
+ stacked_bar_table=analysis_value_counts,
+ color=Colors.INDIGO_PERSIAN,
+ chunk_indices=analysis_chunk_indices,
+ chunk_start_dates=analysis_chunk_start_dates,
+ chunk_end_dates=analysis_chunk_end_dates,
+ annotation='Analysis',
+ showlegend=False,
+ legendgroup=column_name,
+ subplot_args=subplot_args,
+ )
+
+ if analysis_alerts is not None:
+ figure = stacked_bar_alert(
+ figure=figure,
+ alerts=analysis_alerts,
+ stacked_bar_table=analysis_value_counts,
+ color=Colors.RED_IMPERIAL,
+ chunk_indices=analysis_chunk_indices,
+ chunk_start_dates=analysis_chunk_start_dates,
+ chunk_end_dates=analysis_chunk_end_dates,
+ showlegend=True,
+ legendgroup=column_name,
+ subplot_args=subplot_args,
+ )
+
+ return figure
diff --git a/nannyml/distribution/continuous/__init__.py b/nannyml/distribution/continuous/__init__.py
new file mode 100644
index 000000000..cf73eb42e
--- /dev/null
+++ b/nannyml/distribution/continuous/__init__.py
@@ -0,0 +1 @@
+from .calculator import ContinuousDistributionCalculator
diff --git a/nannyml/distribution/continuous/calculator.py b/nannyml/distribution/continuous/calculator.py
new file mode 100644
index 000000000..9f4c45315
--- /dev/null
+++ b/nannyml/distribution/continuous/calculator.py
@@ -0,0 +1,226 @@
+from functools import partial
+from typing import List, Optional, Union
+
+import numpy as np
+import pandas as pd
+from scipy.integrate import cumulative_trapezoid
+from statsmodels import api as sm
+
+from nannyml import Chunker
+from nannyml._typing import Self
+from nannyml.base import AbstractCalculator, _list_missing
+from nannyml.distribution.continuous.result import Result
+from nannyml.exceptions import InvalidArgumentsException
+
+
+class ContinuousDistributionCalculator(AbstractCalculator):
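+    """Calculates the distribution of continuous columns over time.
+
+    For every configured column a kernel density estimate (KDE) is fitted per chunk, yielding the
+    support, density, CDF and quartile data used to render the "joyplot" visualizations.
+
+    A minimal usage sketch, assuming `reference_df` and `analysis_df` are pandas DataFrames
+    containing an (illustrative) 'tenure' column and a 'timestamp' column:
+
+    >>> calc = ContinuousDistributionCalculator(
+    ...     column_names=['tenure'],
+    ...     timestamp_column_name='timestamp',
+    ... ).fit(reference_df)
+    >>> result = calc.calculate(analysis_df)
+    >>> result.plot().show()
+    """
+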
+ def __init__(
+ self,
+ column_names: Union[str, List[str]],
+ timestamp_column_name: Optional[str] = None,
+ chunk_size: Optional[int] = None,
+ chunk_number: Optional[int] = None,
+ chunk_period: Optional[str] = None,
+ chunker: Optional[Chunker] = None,
+ points_per_joy_plot: Optional[int] = None,
+ ):
+ super().__init__(
+ chunk_size,
+ chunk_number,
+ chunk_period,
+ chunker,
+ timestamp_column_name,
+ )
+
+ self.column_names = column_names if isinstance(column_names, List) else [column_names]
+ self.result: Optional[Result] = None
+ self.points_per_joy_plot = points_per_joy_plot
+
+ def _fit(self, reference_data: pd.DataFrame, *args, **kwargs) -> Self:
+ self.result = self._calculate(reference_data)
+ self.result.data[('chunk', 'period')] = 'reference'
+
+ return self
+
+ def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> Result:
+ if data.empty:
+ raise InvalidArgumentsException('data contains no rows. Please provide a valid data set.')
+
+ _list_missing(self.column_names, data)
+
+ result_data = pd.DataFrame(columns=_create_multilevel_index(self.column_names))
+
+ for column in self.column_names:
+ column_distributions_per_chunk = calculate_chunk_distributions(
+ data[column],
+ self.chunker,
+ data.get(self.timestamp_column_name, default=None),
+ points_per_joy_plot=self.points_per_joy_plot,
+ )
+ column_distributions_per_chunk.drop(columns=['key', 'chunk_index'], inplace=True)
+ for c in column_distributions_per_chunk.columns:
+ result_data.loc[:, (column, c)] = column_distributions_per_chunk[c]
+
+ chunks = self.chunker.split(data)
+ result_data[('chunk', 'key')] = [c.key for c in chunks]
+ result_data[('chunk', 'chunk_index')] = [c.chunk_index for c in chunks]
+ result_data[('chunk', 'start_index')] = [c.start_index for c in chunks]
+ result_data[('chunk', 'end_index')] = [c.end_index for c in chunks]
+ result_data[('chunk', 'start_date')] = [c.start_datetime for c in chunks]
+ result_data[('chunk', 'end_date')] = [c.end_datetime for c in chunks]
+ result_data[('chunk', 'period')] = ['analysis' for _ in chunks]
+
+ if self.result is None:
+ self.result = Result(result_data, self.column_names, self.timestamp_column_name, self.chunker)
+ else:
+ self.result = self.result.filter(period='reference')
+ self.result.data = pd.concat([self.result.data, result_data]).reset_index(drop=True)
+
+ return self.result
+
+
+def _get_kde(array, cut=3, clip=(-np.inf, np.inf)):
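+    """Fits a univariate KDE to the given array, returning None when fitting fails (e.g. on empty input)."""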
+ try: # pragma: no cover
+ kde = sm.nonparametric.KDEUnivariate(array)
+ kde.fit(cut=cut, clip=clip)
+ return kde
+ except Exception:
+ return None
+
+
+def _get_kde_support(kde, points_per_joy_plot: Optional[int] = None):
+ if kde is not None: # pragma: no cover
+        return kde.support[:: max(1, len(kde.support) // (points_per_joy_plot or 50))]
+ else:
+ return np.array([])
+
+
+def _get_kde_density(kde, points_per_joy_plot: Optional[int] = None):
+ if kde is not None: # pragma: no cover
+        return kde.density[:: max(1, len(kde.support) // (points_per_joy_plot or 50))]
+ else:
+ return np.array([])
+
+
+def _get_kde_cdf(kde_support, kde_density):
+ if len(kde_support) > 0 and len(kde_density) > 0:
+ cdf = cumulative_trapezoid(y=kde_density, x=kde_support, initial=0)
+ return cdf
+ else:
+ return np.array([])
+
+
+def _get_kde_quartiles(cdf, kde_support, kde_density):
+ if len(cdf) > 0:
+ quartiles = []
+ for quartile in [0.25, 0.50, 0.75]:
+ quartile_index = np.argmax(cdf >= quartile)
+ quartiles.append((kde_support[quartile_index], kde_density[quartile_index], cdf[quartile_index]))
+ return quartiles
+ else:
+ return []
+
+
+def calculate_chunk_distributions(
+ data: Union[np.ndarray, pd.Series],
+ chunker: Chunker,
+ timestamps: Optional[Union[np.ndarray, pd.Series]] = None,
+ data_periods: Optional[Union[np.ndarray, pd.Series]] = None,
+ kde_cut=3,
+ kde_clip=(-np.inf, np.inf),
+ post_kde_clip=None,
+ points_per_joy_plot: Optional[int] = None,
+) -> pd.DataFrame:
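+    """Computes per-chunk KDE distribution data for a single continuous column.
+
+    A KDE is fitted on the data of each chunk and downsampled to roughly `points_per_joy_plot`
+    support points (50 by default). The returned DataFrame holds, per chunk, the KDE support,
+    density, CDF, quartiles and densities scaled by the global density maximum across all chunks.
+    """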
+ if isinstance(data, np.ndarray):
+ data = pd.Series(data, name='data')
+
+ if isinstance(data_periods, np.ndarray):
+ data_periods = pd.Series(data_periods, name='period')
+
+ get_kde_partial_application = partial(_get_kde, cut=kde_cut, clip=kde_clip)
+
+ data_with_chunk_keys = pd.concat(
+ [
+ chunk.data.assign(key=chunk.key, chunk_index=chunk.chunk_index)
+ for chunk in chunker.split(pd.concat([data, timestamps], axis=1))
+ ]
+ )
+
+ group_by_cols = ['chunk_index', 'key']
+ if data_periods is not None:
+ data_with_chunk_keys['period'] = data_periods
+ group_by_cols += ['period']
+ data = (
+ # group by period too, 'key' column can be there for both reference and analysis
+ data_with_chunk_keys.groupby(group_by_cols)[data.name]
+ .apply(get_kde_partial_application)
+ .to_frame('kde')
+ .reset_index()
+ )
+
+ data['kde_support'] = data['kde'].apply(lambda kde: _get_kde_support(kde, points_per_joy_plot))
+ data['kde_density'] = data['kde'].apply(lambda kde: _get_kde_density(kde, points_per_joy_plot))
+ data['kde_cdf'] = data[['kde_support', 'kde_density']].apply(
+        lambda row: _get_kde_cdf(row['kde_support'], row['kde_density']),
+ axis=1,
+ )
+
+ if post_kde_clip:
+ # Clip the kde support to the clip values, adjust the density and cdf to the same length
+ data['kde_support'] = data['kde_support'].apply(lambda x: x[x > post_kde_clip[0]])
+ data['kde_support_len'] = data['kde_support'].apply(lambda x: len(x))
+ data['kde_density'] = data.apply(lambda row: row['kde_density'][-row['kde_support_len'] :], axis=1)
+ data['kde_cdf'] = data.apply(lambda row: row['kde_cdf'][-row['kde_support_len'] :], axis=1)
+ data['kde_support'] = data['kde_support'].apply(lambda x: x[x < post_kde_clip[1]])
+ data['kde_support_len'] = data['kde_support'].apply(lambda x: len(x))
+ data['kde_density'] = data.apply(lambda row: row['kde_density'][: row['kde_support_len']], axis=1)
+ data['kde_cdf'] = data.apply(lambda row: row['kde_cdf'][: row['kde_support_len']], axis=1)
+
+ data['kde_support_len'] = data['kde_support'].apply(lambda x: len(x))
+ data['kde_quartiles'] = data[['kde_cdf', 'kde_support', 'kde_density']].apply(
+        lambda row: _get_kde_quartiles(row['kde_cdf'], row['kde_support'], row['kde_density']),
+ axis=1,
+ )
+ data['kde_density_local_max'] = data['kde_density'].apply(lambda x: max(x) if len(x) > 0 else 0)
+    data['kde_density_global_max'] = data['kde_density_local_max'].max()
+ data['kde_density_scaled'] = data[['kde_density', 'kde_density_global_max']].apply(
+ lambda row: np.divide(np.array(row['kde_density']), row['kde_density_global_max']), axis=1
+ )
+ data['kde_quartiles_scaled'] = data[['kde_quartiles', 'kde_density_global_max']].apply(
+ lambda row: [(q[0], q[1] / row['kde_density_global_max'], q[2]) for q in row['kde_quartiles']], axis=1
+ )
+
+    # TODO: Consider removing redundant columns to reduce fitted calculator memory usage
+    # The kde object causes issues when pickling the calculator and is no longer needed, so drop it here
+ del data['kde']
+
+ return data
+
+
+def _create_multilevel_index(column_names: List[str]) -> pd.MultiIndex:
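+    """Builds the two-level column index combining 'chunk' metadata columns with per-column distribution data."""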
+ chunk_column_names = ['key', 'chunk_index', 'start_index', 'end_index', 'start_date', 'end_date', 'period']
+ distribution_column_names = [
+ 'kde_support',
+ 'kde_density',
+ 'kde_cdf',
+ 'kde_support_len',
+ 'kde_quartiles',
+ 'kde_density_local_max',
+ 'kde_density_global_max',
+ 'kde_density_scaled',
+ 'kde_quartiles_scaled',
+ ]
+ chunk_tuples = [('chunk', chunk_column_name) for chunk_column_name in chunk_column_names]
+ continuous_column_tuples = [
+ (column_name, distribution_column_name)
+ for column_name in column_names
+ for distribution_column_name in distribution_column_names
+ ]
+
+ tuples = chunk_tuples + continuous_column_tuples
+
+ return pd.MultiIndex.from_tuples(tuples)
diff --git a/nannyml/distribution/continuous/result.py b/nannyml/distribution/continuous/result.py
new file mode 100644
index 000000000..42788bdd5
--- /dev/null
+++ b/nannyml/distribution/continuous/result.py
@@ -0,0 +1,310 @@
+import math
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+
+from nannyml import Chunker
+from nannyml._typing import Key
+from nannyml.base import PerColumnResult
+from nannyml.drift.univariate.result import Result as DriftResult
+from nannyml.exceptions import InvalidArgumentsException
+from nannyml.plots import Colors, Figure, Hover, is_time_based_x_axis
+from nannyml.plots.components.joy_plot import alert as joy_alert
+from nannyml.plots.components.joy_plot import joy
+
+
+class Result(PerColumnResult):
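+    """Contains the results of the continuous distribution calculation and provides plotting functionality."""
+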
+ def __init__(
+ self,
+ results_data: pd.DataFrame,
+ column_names: List[str],
+ timestamp_column_name: Optional[str],
+ chunker: Chunker,
+ ):
+ super().__init__(results_data, column_names)
+
+ self.timestamp_column_name = timestamp_column_name
+ self.chunker = chunker
+
+ def keys(self) -> List[Key]:
+ return [Key(properties=(c,), display_names=(c,)) for c in self.column_names]
+
+ def plot(self, drift_result: Optional[DriftResult] = None, *args, **kwargs) -> go.Figure:
+ """
+ Creates a "joyplot over time" visualization to illustrate continuous distribution changes over time.
+
+ Parameters
+ ----------
+ drift_result: Optional[nannyml.drift.univariate.Result]
+            The result of a univariate drift calculation. When set, it is used to look up alerts that
+            occurred for each column / drift method combination in the drift calculation result. For each
+            of these combinations a distribution plot of the column is rendered, showing the alerts for
+            that drift method. When the `drift_result` parameter is not set, no alerts are rendered on
+            the distribution plots.
+ """
+
+ if drift_result and not isinstance(drift_result, DriftResult):
+ raise InvalidArgumentsException(
+                'currently the drift_result parameter only supports results of the UnivariateDriftCalculator.'
+ )
+
+ if drift_result:
+ self.check_is_compatible_with(drift_result)
+
+ return (
+ _plot_continuous_distribution_with_alerts(self, drift_result)
+ if drift_result
+ else _plot_continuous_distribution(self)
+ )
+
+ def check_is_compatible_with(self, drift_result: DriftResult):
+ # Check if all distribution columns are present in the drift result
+ drift_column_names = set([col for tup in drift_result.keys() for col, _ in tup])
+ distribution_column_names = set(self.column_names)
+
+ missing_columns = distribution_column_names.difference(drift_column_names)
+ if len(missing_columns) > 0:
+ raise InvalidArgumentsException(
+                "cannot render distribution plots with alerts. The following columns are not "
+                f"present in the drift results: {list(missing_columns)}"
+ )
+
+ # Check if both results use the same X-axis
+ drift_result_is_time_based = is_time_based_x_axis(drift_result.chunk_start_dates, drift_result.chunk_end_dates)
+ distr_result_is_time_based = is_time_based_x_axis(self.chunk_start_dates, self.chunk_end_dates)
+
+ if drift_result_is_time_based != distr_result_is_time_based:
+ raise InvalidArgumentsException(
+                "cannot render distribution plots with alerts. Drift results are"
+                f"{'' if drift_result_is_time_based else ' not'} time-based, distribution results are"
+                f"{'' if distr_result_is_time_based else ' not'} time-based. Drift and distribution results "
+                "should either both be time-based (have a timestamp column) or neither."
+ )
+
+
+def _plot_continuous_distribution(
+ result: Result,
+ title: Optional[str] = 'Column distributions',
+ figure: Optional[go.Figure] = None,
+ x_axis_time_title: str = 'Time',
+ x_axis_chunk_title: str = 'Chunk',
+ y_axis_title: str = 'Values',
+ figure_args: Optional[Dict[str, Any]] = None,
+ subplot_title_format: str = '{display_names[0]} distribution',
+ number_of_columns: Optional[int] = None,
+) -> go.Figure:
+ number_of_plots = len(result.keys())
+ if number_of_columns is None:
+ number_of_columns = min(number_of_plots, 1)
+ number_of_rows = math.ceil(number_of_plots / number_of_columns)
+
+ if figure_args is None:
+ figure_args = {}
+
+ if figure is None:
+ figure = Figure(
+ **dict(
+ title=title,
+ x_axis_title=x_axis_time_title
+ if is_time_based_x_axis(result.chunk_start_dates, result.chunk_end_dates)
+ else x_axis_chunk_title,
+ y_axis_title=y_axis_title,
+ legend=dict(traceorder="grouped", itemclick=False, itemdoubleclick=False),
+ height=number_of_plots * 500 / number_of_columns,
+ subplot_args=dict(
+ cols=number_of_columns,
+ rows=number_of_rows,
+ subplot_titles=[
+ subplot_title_format.format(display_names=key.display_names) for key in result.keys()
+ ],
+ ),
+ **figure_args,
+ )
+ )
+
+ reference_result = result.filter(period='reference')
+ analysis_result = result.filter(period='analysis')
+
+ for idx, key in enumerate(result.keys()):
+ row = (idx // number_of_columns) + 1
+ col = (idx % number_of_columns) + 1
+
+ (column_name,) = key.properties
+
+ figure = _plot_joyplot(
+ figure=figure,
+ row=row,
+ col=col,
+ metric_display_name='',
+ reference_distributions=reference_result.to_df().loc[:, (column_name,)],
+ reference_alerts=None,
+ reference_chunk_keys=reference_result.chunk_keys,
+ reference_chunk_periods=reference_result.chunk_periods,
+ reference_chunk_indices=reference_result.chunk_indices,
+ reference_chunk_start_dates=reference_result.chunk_start_dates,
+ reference_chunk_end_dates=reference_result.chunk_end_dates,
+ analysis_distributions=analysis_result.to_df().loc[:, (column_name,)],
+ analysis_alerts=None,
+ analysis_chunk_keys=analysis_result.chunk_keys,
+ analysis_chunk_periods=analysis_result.chunk_periods,
+ analysis_chunk_indices=analysis_result.chunk_indices,
+ analysis_chunk_start_dates=analysis_result.chunk_start_dates,
+ analysis_chunk_end_dates=analysis_result.chunk_end_dates,
+ )
+
+ return figure
+
+
+def _plot_continuous_distribution_with_alerts(
+ result: Result,
+ drift_result: DriftResult,
+ title: Optional[str] = 'Column distributions',
+ figure: Optional[go.Figure] = None,
+ x_axis_time_title: str = 'Time',
+ x_axis_chunk_title: str = 'Chunk',
+ y_axis_title: str = 'Values',
+ figure_args: Optional[Dict[str, Any]] = None,
+ subplot_title_format: str = '{display_names[0]} distribution (alerts for {display_names[1]})',
+ number_of_columns: Optional[int] = None,
+) -> go.Figure:
+ number_of_plots = len(drift_result.keys())
+ if number_of_columns is None:
+ number_of_columns = min(number_of_plots, 1)
+ number_of_rows = math.ceil(number_of_plots / number_of_columns)
+
+ if figure_args is None:
+ figure_args = {}
+
+ if figure is None:
+ figure = Figure(
+ **dict(
+ title=title,
+ x_axis_title=x_axis_time_title
+ if is_time_based_x_axis(result.chunk_start_dates, result.chunk_end_dates)
+ else x_axis_chunk_title,
+ y_axis_title=y_axis_title,
+ legend=dict(traceorder="grouped", itemclick=False, itemdoubleclick=False),
+ height=number_of_plots * 500 / number_of_columns,
+ subplot_args=dict(
+ cols=number_of_columns,
+ rows=number_of_rows,
+ subplot_titles=[
+ subplot_title_format.format(display_names=key.display_names) for key in drift_result.keys()
+ ],
+ ),
+ **figure_args,
+ )
+ )
+
+ reference_result = result.filter(period='reference')
+ reference_result.data.sort_index(inplace=True)
+ analysis_result = result.filter(period='analysis')
+ analysis_result.data.sort_index(inplace=True)
+
+ for idx, drift_key in enumerate(drift_result.keys()):
+ row = (idx // number_of_columns) + 1
+ col = (idx % number_of_columns) + 1
+
+ (column_name, method_name) = drift_key.properties
+
+        # Reference alerts are not rendered; only analysis alerts are shown on the plot
+ analysis_alerts = drift_result.filter(period='analysis').alerts(drift_key)
+
+ figure = _plot_joyplot(
+ figure=figure,
+ row=row,
+ col=col,
+ metric_display_name='',
+ reference_distributions=reference_result.to_df().xs(column_name, level=0, axis=1),
+ reference_alerts=None,
+ reference_chunk_keys=reference_result.chunk_keys,
+ reference_chunk_periods=reference_result.chunk_periods,
+ reference_chunk_indices=reference_result.chunk_indices,
+ reference_chunk_start_dates=reference_result.chunk_start_dates,
+ reference_chunk_end_dates=reference_result.chunk_end_dates,
+ analysis_distributions=analysis_result.to_df().xs(column_name, level=0, axis=1),
+ analysis_alerts=analysis_alerts,
+ analysis_chunk_keys=analysis_result.chunk_keys,
+ analysis_chunk_periods=analysis_result.chunk_periods,
+ analysis_chunk_indices=analysis_result.chunk_indices,
+ analysis_chunk_start_dates=analysis_result.chunk_start_dates,
+ analysis_chunk_end_dates=analysis_result.chunk_end_dates,
+ )
+
+ return figure
+
+
+def _plot_joyplot(
+ figure: go.Figure,
+ metric_display_name: str,
+ reference_distributions: pd.DataFrame,
+ analysis_distributions: pd.DataFrame,
+ reference_alerts: Optional[Union[np.ndarray, pd.Series]] = None,
+ reference_chunk_keys: Optional[Union[np.ndarray, pd.Series]] = None,
+ reference_chunk_periods: Optional[Union[np.ndarray, pd.Series]] = None,
+ reference_chunk_indices: Optional[Union[np.ndarray, pd.Series]] = None,
+ reference_chunk_start_dates: Optional[Union[np.ndarray, pd.Series]] = None,
+ reference_chunk_end_dates: Optional[Union[np.ndarray, pd.Series]] = None,
+ analysis_alerts: Optional[Union[np.ndarray, pd.Series]] = None,
+ analysis_chunk_keys: Optional[Union[np.ndarray, pd.Series]] = None,
+ analysis_chunk_periods: Optional[Union[np.ndarray, pd.Series]] = None,
+ analysis_chunk_indices: Optional[Union[np.ndarray, pd.Series]] = None,
+ analysis_chunk_start_dates: Optional[Union[np.ndarray, pd.Series]] = None,
+ analysis_chunk_end_dates: Optional[Union[np.ndarray, pd.Series]] = None,
+ row: Optional[int] = None,
+ col: Optional[int] = None,
+ hover: Optional[Hover] = None,
+) -> go.Figure:
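+    """Renders a joyplot (ridgeline plot) for a single continuous column, as a subplot when `row` and `col` are set.
+
+    Reference and analysis distributions share a single x-axis: analysis chunk indices are shifted
+    to continue where the reference chunks end. When `analysis_alerts` is given, the distributions
+    of alerting chunks are drawn in the alert color.
+    """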
+ is_subplot = row is not None and col is not None
+ subplot_args = dict(row=row, col=col) if is_subplot else None
+
+ has_reference_results = reference_chunk_indices is not None and len(reference_chunk_indices) > 0
+
+ if figure is None:
+        figure = Figure(title='continuous distribution', x_axis_title='time', y_axis_title='value', height=500)
+
+ if has_reference_results: # TODO: move distribution calculations to calculator run
+ figure = joy(
+ fig=figure,
+ data_distributions=reference_distributions,
+ chunk_keys=reference_chunk_keys,
+ chunk_indices=reference_chunk_indices,
+ chunk_start_dates=reference_chunk_start_dates,
+ chunk_end_dates=reference_chunk_end_dates,
+ name='Reference',
+ color=Colors.BLUE_SKY_CRAYOLA,
+ subplot_args=subplot_args,
+ )
+
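+        # Shift analysis chunk indices so the analysis period continues where the reference period ends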
+ assert reference_chunk_indices is not None
+ analysis_chunk_indices = analysis_chunk_indices + (max(reference_chunk_indices) + 1)
+
+ figure = joy(
+ fig=figure,
+ data_distributions=analysis_distributions,
+ chunk_keys=analysis_chunk_keys,
+ chunk_indices=analysis_chunk_indices,
+ chunk_start_dates=analysis_chunk_start_dates,
+ chunk_end_dates=analysis_chunk_end_dates,
+ name='Analysis',
+ color=Colors.INDIGO_PERSIAN,
+ subplot_args=subplot_args,
+ )
+
+ if analysis_alerts is not None:
+ figure = joy_alert(
+ fig=figure,
+ alerts=analysis_alerts,
+ data_distributions=analysis_distributions,
+ color=Colors.RED_IMPERIAL,
+ name='Alerts',
+ chunk_keys=analysis_chunk_keys,
+ chunk_indices=analysis_chunk_indices,
+ chunk_start_dates=analysis_chunk_start_dates,
+ chunk_end_dates=analysis_chunk_end_dates,
+ subplot_args=subplot_args,
+ )
+
+ return figure
diff --git a/nannyml/plots/components/joy_plot.py b/nannyml/plots/components/joy_plot.py
index 69b0d970f..554a562cf 100644
--- a/nannyml/plots/components/joy_plot.py
+++ b/nannyml/plots/components/joy_plot.py
@@ -150,6 +150,7 @@ def joy(
data_distributions: pd.DataFrame,
color: str,
name: str,
+ chunk_keys: Optional[Union[np.ndarray, pd.Series]] = None,
chunk_start_dates: Optional[Union[np.ndarray, pd.Series]] = None,
chunk_end_dates: Optional[Union[np.ndarray, pd.Series]] = None,
chunk_indices: Optional[Union[np.ndarray, pd.Series]] = None,
@@ -158,7 +159,9 @@ def joy(
plot_quartiles: bool = True,
**kwargs,
) -> go.Figure:
- chunk_indices, chunk_start_dates, chunk_end_dates = ensure_numpy(chunk_indices, chunk_start_dates, chunk_end_dates)
+ chunk_keys, chunk_indices, chunk_start_dates, chunk_end_dates = ensure_numpy(
+ chunk_keys, chunk_indices, chunk_start_dates, chunk_end_dates
+ )
joy_overlap = 1
if subplot_args is None:
@@ -228,7 +231,7 @@ def joy(
if plot_quartiles:
for kde_quartile in kde_quartiles:
hover = Hover(template='Chunk %{chunk_key}: %{x_coordinate}, %{quartile}')
- hover.add(row['chunk_key'], name='chunk_key')
+ hover.add(chunk_keys[i] if chunk_keys is not None else row['chunk_key'], name='chunk_key')
hover.add(
render_x_coordinate(chunk_indices, chunk_start_dates, chunk_end_dates)[i], name='x_coordinate'
)
@@ -259,6 +262,7 @@ def alert(
color: str,
name: str,
alerts: Union[np.ndarray, pd.Series],
+ chunk_keys: Optional[Union[np.ndarray, pd.Series]] = None,
chunk_start_dates: Optional[Union[np.ndarray, pd.Series]] = None,
chunk_end_dates: Optional[Union[np.ndarray, pd.Series]] = None,
chunk_indices: Optional[Union[np.ndarray, pd.Series]] = None,
@@ -269,6 +273,7 @@ def alert(
) -> go.Figure:
data = pd.DataFrame(
{
+ 'chunk_keys': chunk_keys,
'chunk_indices': chunk_indices,
'chunk_start_dates': chunk_start_dates,
'chunk_end_dates': chunk_end_dates,
@@ -282,6 +287,7 @@ def alert(
alerts_data[data_distributions.columns],
color,
name,
+ alerts_data['chunk_keys'],
alerts_data['chunk_start_dates'],
alerts_data['chunk_end_dates'],
alerts_data['chunk_indices'],
diff --git a/nannyml/plots/components/stacked_bar_plot.py b/nannyml/plots/components/stacked_bar_plot.py
index 89ba93322..9d0d8bd08 100644
--- a/nannyml/plots/components/stacked_bar_plot.py
+++ b/nannyml/plots/components/stacked_bar_plot.py
@@ -113,13 +113,21 @@ def stacked_bar(
data = stacked_bar_table.loc[stacked_bar_table[column_name] == category]
if is_time_based_x_axis(chunk_start_dates, chunk_end_dates):
- x = chunk_start_dates
+ x = data.get('start_datetime', chunk_start_dates)
else:
- x = chunk_indices
+ x = data.get('chunk_indices', chunk_indices)
hover = Hover(template="Chunk %{chunk_key}: %{x_coordinate}; (%{value_counts_normalised}, %{value_counts})")
hover.add(data['chunk_key'], name='chunk_key')
- hover.add(render_x_coordinate(data['chunk_indices'], chunk_start_dates, chunk_end_dates), name='x_coordinate')
+ hover.add(
+ render_x_coordinate(
+ data['chunk_indices'],
+ data.get('start_datetime', chunk_start_dates),
+ data.get('end_datetime', chunk_end_dates),
+ ),
+ name='x_coordinate',
+ )
hover.add(data['value_counts_normalised'], name='value_counts_normalised')
hover.add(data['value_counts'], name='value_counts')
diff --git a/nannyml/runner.py b/nannyml/runner.py
index 21221dbbe..38e395de7 100644
--- a/nannyml/runner.py
+++ b/nannyml/runner.py
@@ -16,6 +16,8 @@
from nannyml.config import Config, InputDataConfig, StoreConfig, WriterConfig
from nannyml.data_quality.missing import MissingValuesCalculator
from nannyml.data_quality.unseen import UnseenValuesCalculator
+from nannyml.distribution.categorical import CategoricalDistributionCalculator
+from nannyml.distribution.continuous import ContinuousDistributionCalculator
from nannyml.drift.multivariate.data_reconstruction import DataReconstructionDriftCalculator
from nannyml.drift.univariate import UnivariateDriftCalculator
from nannyml.exceptions import InvalidArgumentsException
@@ -78,6 +80,8 @@ class CalculatorFactory:
'dle': DLE,
'missing_values': MissingValuesCalculator,
'unseen_values': UnseenValuesCalculator,
+ 'continuous_distribution': ContinuousDistributionCalculator,
+ 'categorical_distribution': CategoricalDistributionCalculator,
'summary_stats_avg': SummaryStatsAvgCalculator,
'summary_stats_row_count': SummaryStatsRowCountCalculator,
'summary_stats_median': SummaryStatsMedianCalculator,