Skip to content

Commit

Permalink
Use InvalidArgumentsException, some renames
Browse files Browse the repository at this point in the history
  • Loading branch information
nnansters committed May 16, 2024
1 parent 558b5d7 commit ba1d023
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 49 deletions.
18 changes: 10 additions & 8 deletions nannyml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ def _common_nan_removal_dataframe(data: pd.DataFrame, selected_columns: List[str
Boolean whether the resulting data are contain any rows (false) or not (true)
"""
if not set(selected_columns) <= set(data.columns):
raise ValueError(
raise InvalidArgumentsException(
f"Selected columns: {selected_columns} not all present in provided data columns {list(data.columns)}"
)
df = data.dropna(axis=0, how='any', inplace=False, subset=selected_columns).reset_index(drop=True).infer_objects()
Expand All @@ -646,12 +646,12 @@ def _common_nan_removal_dataframe(data: pd.DataFrame, selected_columns: List[str

def _common_nan_removal_ndarrays(data: Sequence[np.array], selected_columns: List[int]) -> Tuple[pd.DataFrame, bool]:
"""
Remove rows of numpy ndarrays containing NaN values on selected columns.
Remove rows of numpy arrays containing NaN values on selected columns.
Parameters
----------
data: Sequence[np.array]
Sequence containing numpy ndarrays.
Sequence containing numpy arrays.
selected_columns: List[int]
List containing the indices of column numbers
Expand All @@ -665,7 +665,7 @@ def _common_nan_removal_ndarrays(data: Sequence[np.array], selected_columns: Lis
"""
# Check if all selected_columns indices are valid for the first ndarray
if not all(col < len(data) for col in selected_columns):
raise ValueError(
raise InvalidArgumentsException(
f"Selected columns: {selected_columns} not all present in provided data columns with shape {data[0].shape}"
)

Expand All @@ -680,22 +680,24 @@ def _common_nan_removal_ndarrays(data: Sequence[np.array], selected_columns: Lis


@overload
def common_nan_removal(data: pd.DataFrame, selected_columns: List[str]) -> Tuple[pd.DataFrame, bool]: ...
def common_nan_removal(data: pd.DataFrame, selected_columns: List[str]) -> Tuple[pd.DataFrame, bool]:
...


@overload
def common_nan_removal(data: Sequence[np.array], selected_columns: List[int]) -> Tuple[pd.DataFrame, bool]: ...
def common_nan_removal(data: Sequence[np.array], selected_columns: List[int]) -> Tuple[pd.DataFrame, bool]:
...


def common_nan_removal(
data: Union[pd.DataFrame, Sequence[np.array]], selected_columns: Union[List[str], List[int]]
data: Union[pd.DataFrame, Sequence[np.array]], selected_columns: Union[List[str], List[int]]
) -> Tuple[pd.DataFrame, bool]:
"""
Wrapper function to handle both pandas DataFrame and sequences of numpy ndarrays.
Parameters
----------
data: Union[pd.DataFrame, Sequence[np.ndarray]]
data: Union[pd.DataFrame, Sequence[np.array]]
Pandas dataframe or sequence of numpy ndarrays containing data.
selected_columns: Union[List[str], List[int]]
List containing the column names or indices
Expand Down
59 changes: 18 additions & 41 deletions tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,24 @@

import numpy as np
import pandas as pd
import pytest

from nannyml.base import common_nan_removal
from nannyml.exceptions import InvalidArgumentsException


def test_common_nan_removal_dataframe():
data = pd.DataFrame({
'A': [1, 2, np.nan, 4],
'B': [5, np.nan, 7, 8],
'C': [9, 10, 11, np.nan]
})
data = pd.DataFrame({'A': [1, 2, np.nan, 4], 'B': [5, np.nan, 7, 8], 'C': [9, 10, 11, np.nan]})
selected_columns = ['A', 'B']
df_cleaned, is_empty = common_nan_removal(data, selected_columns)

expected_df = pd.DataFrame({
'A': [1, 4],
'B': [5, 8],
'C': [9, np.nan]
}).reset_index(drop=True)
expected_df = pd.DataFrame({'A': [1, 4], 'B': [5, 8], 'C': [9, np.nan]}).reset_index(drop=True)

pd.testing.assert_frame_equal(df_cleaned, expected_df, check_dtype=False) # ignore types because of infer_objects
assert not is_empty


def test_common_nan_removal_dataframe_all_nan():
data = pd.DataFrame({
'A': [np.nan, np.nan],
'B': [np.nan, np.nan],
'C': [np.nan, np.nan]
})
data = pd.DataFrame({'A': [np.nan, np.nan], 'B': [np.nan, np.nan], 'C': [np.nan, np.nan]})
selected_columns = ['A', 'B']
df_cleaned, is_empty = common_nan_removal(data, selected_columns)

Expand All @@ -40,23 +28,20 @@ def test_common_nan_removal_dataframe_all_nan():
assert is_empty


def test_common_nan_removal_ndarrays():
data = [
np.array([1, 5, 9]),
np.array([2, np.nan, 10]),
np.array([np.nan, 7, 11]),
np.array([4, 8, np.nan])
]
def test_common_nan_removal_arrays():
data = [np.array([1, 5, 9]), np.array([2, np.nan, 10]), np.array([np.nan, 7, 11]), np.array([4, 8, np.nan])]
selected_columns_indices = [0, 1] # Corresponds to columns 'A' and 'B'

df_cleaned, is_empty = common_nan_removal(data, selected_columns_indices)

expected_df = pd.DataFrame({
'col_0': [1, 9],
'col_1': [2, 10],
'col_2': [np.nan, 11],
'col_3': [4, np.nan],
}).reset_index(drop=True)
expected_df = pd.DataFrame(
{
'col_0': [1, 9],
'col_1': [2, 10],
'col_2': [np.nan, 11],
'col_3': [4, np.nan],
}
).reset_index(drop=True)

pd.testing.assert_frame_equal(df_cleaned, expected_df, check_dtype=False)
assert not is_empty
Expand All @@ -67,28 +52,21 @@ def test_common_nan_removal_arrays_all_nan():
np.array([np.nan, np.nan]),
np.array([np.nan, np.nan]),
np.array([np.nan, np.nan]),

]
selected_columns_indices = [0, 1] # Corresponds to columns 'A' and 'B'

df_cleaned, is_empty = common_nan_removal(data, selected_columns_indices)

expected_df = pd.DataFrame(columns=[
'col_0', 'col_1', 'col_2'
])
expected_df = pd.DataFrame(columns=['col_0', 'col_1', 'col_2'])

pd.testing.assert_frame_equal(df_cleaned, expected_df, check_index_type=False, check_dtype=False)
assert is_empty


def test_invalid_dataframe_columns():
data = pd.DataFrame({
'A': [1, 2, np.nan, 4],
'B': [5, np.nan, 7, 8],
'C': [9, 10, 11, np.nan]
})
data = pd.DataFrame({'A': [1, 2, np.nan, 4], 'B': [5, np.nan, 7, 8], 'C': [9, 10, 11, np.nan]})
selected_columns = ['A', 'D'] # 'D' does not exist
with pytest.raises(ValueError):
with pytest.raises(InvalidArgumentsException):
common_nan_removal(data, selected_columns)


Expand All @@ -97,9 +75,8 @@ def test_invalid_array_columns():
np.array([np.nan, np.nan]),
np.array([np.nan, np.nan]),
np.array([np.nan, np.nan]),

]
selected_columns_indices = [0, 3] # Index 3 does not exist in ndarray

with pytest.raises(ValueError):
with pytest.raises(InvalidArgumentsException):
common_nan_removal(data, selected_columns_indices)

0 comments on commit ba1d023

Please sign in to comment.