From ba1d02331737a9f28e11b73ae05a7abd1c1c8aff Mon Sep 17 00:00:00 2001 From: Niels Nuyttens Date: Thu, 16 May 2024 10:55:05 +0100 Subject: [PATCH] Use InvalidArgumentsException, some renames --- nannyml/base.py | 18 +++++++------- tests/test_base.py | 59 ++++++++++++++-------------------------------- 2 files changed, 28 insertions(+), 49 deletions(-) diff --git a/nannyml/base.py b/nannyml/base.py index da4b35ea..9cc61326 100644 --- a/nannyml/base.py +++ b/nannyml/base.py @@ -636,7 +636,7 @@ def _common_nan_removal_dataframe(data: pd.DataFrame, selected_columns: List[str Boolean whether the resulting data are contain any rows (false) or not (true) """ if not set(selected_columns) <= set(data.columns): - raise ValueError( + raise InvalidArgumentsException( f"Selected columns: {selected_columns} not all present in provided data columns {list(data.columns)}" ) df = data.dropna(axis=0, how='any', inplace=False, subset=selected_columns).reset_index(drop=True).infer_objects() @@ -646,12 +646,12 @@ def _common_nan_removal_dataframe(data: pd.DataFrame, selected_columns: List[str def _common_nan_removal_ndarrays(data: Sequence[np.array], selected_columns: List[int]) -> Tuple[pd.DataFrame, bool]: """ - Remove rows of numpy ndarrays containing NaN values on selected columns. + Remove rows of numpy arrays containing NaN values on selected columns. Parameters ---------- data: Sequence[np.array] - Sequence containing numpy ndarrays. + Sequence containing numpy arrays. selected_columns: List[int] List containing the indices of column numbers @@ -665,7 +665,7 @@ def _common_nan_removal_ndarrays(data: Sequence[np.array], selected_columns: Lis """ # Check if all selected_columns indices are valid for the first ndarray if not all(col < len(data) for col in selected_columns): - raise ValueError( + raise InvalidArgumentsException( f"Selected columns: {selected_columns} not all present in provided data columns with shape {data[0].shape}" ) @@ -680,22 +680,24 @@ def _common_nan_removal_ndarrays(data: Sequence[np.array], selected_columns: Lis @overload -def common_nan_removal(data: pd.DataFrame, selected_columns: List[str]) -> Tuple[pd.DataFrame, bool]: ... +def common_nan_removal(data: pd.DataFrame, selected_columns: List[str]) -> Tuple[pd.DataFrame, bool]: + ... @overload -def common_nan_removal(data: Sequence[np.array], selected_columns: List[int]) -> Tuple[pd.DataFrame, bool]: ... +def common_nan_removal(data: Sequence[np.array], selected_columns: List[int]) -> Tuple[pd.DataFrame, bool]: + ... def common_nan_removal( - data: Union[pd.DataFrame, Sequence[np.array]], selected_columns: Union[List[str], List[int]] + data: Union[pd.DataFrame, Sequence[np.array]], selected_columns: Union[List[str], List[int]] ) -> Tuple[pd.DataFrame, bool]: """ Wrapper function to handle both pandas DataFrame and sequences of numpy ndarrays. Parameters ---------- - data: Union[pd.DataFrame, Sequence[np.ndarray]] + data: Union[pd.DataFrame, Sequence[np.array]] Pandas dataframe or sequence of numpy ndarrays containing data. selected_columns: Union[List[str], List[int]] List containing the column names or indices diff --git a/tests/test_base.py b/tests/test_base.py index b2a69e16..d553f725 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,36 +1,24 @@ - import numpy as np import pandas as pd import pytest from nannyml.base import common_nan_removal +from nannyml.exceptions import InvalidArgumentsException def test_common_nan_removal_dataframe(): - data = pd.DataFrame({ - 'A': [1, 2, np.nan, 4], - 'B': [5, np.nan, 7, 8], - 'C': [9, 10, 11, np.nan] - }) + data = pd.DataFrame({'A': [1, 2, np.nan, 4], 'B': [5, np.nan, 7, 8], 'C': [9, 10, 11, np.nan]}) selected_columns = ['A', 'B'] df_cleaned, is_empty = common_nan_removal(data, selected_columns) - expected_df = pd.DataFrame({ - 'A': [1, 4], - 'B': [5, 8], - 'C': [9, np.nan] - }).reset_index(drop=True) + expected_df = pd.DataFrame({'A': [1, 4], 'B': [5, 8], 'C': [9, np.nan]}).reset_index(drop=True) pd.testing.assert_frame_equal(df_cleaned, expected_df, check_dtype=False) # ignore types because of infer_objects assert not is_empty def test_common_nan_removal_dataframe_all_nan(): - data = pd.DataFrame({ - 'A': [np.nan, np.nan], - 'B': [np.nan, np.nan], - 'C': [np.nan, np.nan] - }) + data = pd.DataFrame({'A': [np.nan, np.nan], 'B': [np.nan, np.nan], 'C': [np.nan, np.nan]}) selected_columns = ['A', 'B'] df_cleaned, is_empty = common_nan_removal(data, selected_columns) @@ -40,23 +28,20 @@ def test_common_nan_removal_dataframe_all_nan(): assert is_empty -def test_common_nan_removal_ndarrays(): - data = [ - np.array([1, 5, 9]), - np.array([2, np.nan, 10]), - np.array([np.nan, 7, 11]), - np.array([4, 8, np.nan]) - ] +def test_common_nan_removal_arrays(): + data = [np.array([1, 5, 9]), np.array([2, np.nan, 10]), np.array([np.nan, 7, 11]), np.array([4, 8, np.nan])] selected_columns_indices = [0, 1] # Corresponds to columns 'A' and 'B' df_cleaned, is_empty = common_nan_removal(data, selected_columns_indices) - expected_df = pd.DataFrame({ - 'col_0': [1, 9], - 'col_1': [2, 10], - 'col_2': [np.nan, 11], - 'col_3': [4, np.nan], - }).reset_index(drop=True) + expected_df = pd.DataFrame( + { + 'col_0': [1, 9], + 'col_1': [2, 10], + 'col_2': [np.nan, 11], + 'col_3': [4, np.nan], + } + ).reset_index(drop=True) pd.testing.assert_frame_equal(df_cleaned, expected_df, check_dtype=False) assert not is_empty @@ -67,28 +52,21 @@ def test_common_nan_removal_arrays_all_nan(): np.array([np.nan, np.nan]), np.array([np.nan, np.nan]), np.array([np.nan, np.nan]), - ] selected_columns_indices = [0, 1] # Corresponds to columns 'A' and 'B' df_cleaned, is_empty = common_nan_removal(data, selected_columns_indices) - expected_df = pd.DataFrame(columns=[ - 'col_0', 'col_1', 'col_2' - ]) + expected_df = pd.DataFrame(columns=['col_0', 'col_1', 'col_2']) pd.testing.assert_frame_equal(df_cleaned, expected_df, check_index_type=False, check_dtype=False) assert is_empty def test_invalid_dataframe_columns(): - data = pd.DataFrame({ - 'A': [1, 2, np.nan, 4], - 'B': [5, np.nan, 7, 8], - 'C': [9, 10, 11, np.nan] - }) + data = pd.DataFrame({'A': [1, 2, np.nan, 4], 'B': [5, np.nan, 7, 8], 'C': [9, 10, 11, np.nan]}) selected_columns = ['A', 'D'] # 'D' does not exist - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentsException): common_nan_removal(data, selected_columns) @@ -97,9 +75,8 @@ def test_invalid_array_columns(): np.array([np.nan, np.nan]), np.array([np.nan, np.nan]), np.array([np.nan, np.nan]), - ] selected_columns_indices = [0, 3] # Index 3 does not exist in ndarray - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentsException): common_nan_removal(data, selected_columns_indices)