Skip to content

Commit 22ca9e4

Browse files
committed
Fix apply
1 parent c43918d commit 22ca9e4

File tree

4 files changed

+13
-2
lines changed

4 files changed

+13
-2
lines changed

mars/dataframe/groupby/aggregation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def get(self):
9494
"skew": lambda x, bias=False: x.skew(bias=bias),
9595
"kurt": lambda x, bias=False: x.kurt(bias=bias),
9696
"kurtosis": lambda x, bias=False: x.kurtosis(bias=bias),
97+
"nunique": lambda x: x.nunique(),
9798
}
9899
_series_col_name = "col_name"
99100

@@ -720,7 +721,8 @@ def _do_custom_agg(op, custom_reduction, *input_objs):
720721
result = (result,)
721722

722723
if out.ndim == 2:
723-
result = tuple(r.to_frame().T for r in result)
724+
if result[0].ndim == 1:
725+
result = tuple(r.to_frame().T for r in result)
724726
if op.stage == OperandStage.agg:
725727
result = tuple(r.astype(out.dtypes) for r in result)
726728
else:

mars/dataframe/reduction/aggregation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def where_function(cond, var1, var2):
7878
"skew": lambda x, skipna=True, bias=False: x.skew(skipna=skipna, bias=bias),
7979
"kurt": lambda x, skipna=True, bias=False: x.kurt(skipna=skipna, bias=bias),
8080
"kurtosis": lambda x, skipna=True, bias=False: x.kurtosis(skipna=skipna, bias=bias),
81+
"nunique": lambda x: x.nunique(),
8182
}
8283

8384

mars/dataframe/reduction/core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -972,13 +972,15 @@ def _compile_function(self, func, func_name=None, ndim=1) -> ReductionSteps:
972972
else:
973973
map_func_name, agg_func_name = step_func_name, step_func_name
974974

975+
op_custom_reduction = getattr(t.op, "custom_reduction", None)
976+
975977
# build agg description
976978
agg_funcs.append(
977979
ReductionAggStep(
978980
agg_input_key,
979981
map_func_name,
980982
agg_func_name,
981-
custom_reduction,
983+
op_custom_reduction or custom_reduction,
982984
t.key,
983985
output_limit,
984986
t.op.get_reduction_args(axis=self._axis),

mars/dataframe/reduction/tests/test_reduction_execution.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,12 @@ def test_nunique(setup, check_ref_counts):
671671
expected = data1.nunique(axis=1)
672672
pd.testing.assert_series_equal(result, expected)
673673

674+
# test with agg func
675+
df = md.DataFrame(data1, chunk_size=3)
676+
result = df.agg("nunique").execute().fetch()
677+
expected = data1.agg("nunique")
678+
pd.testing.assert_series_equal(result, expected)
679+
674680

675681
@pytest.mark.skipif(pa is None, reason="pyarrow not installed")
676682
def test_use_arrow_dtype_nunique(setup, check_ref_counts):

0 commit comments

Comments
 (0)