25
25
from ... import opcodes as OperandDef
26
26
from ...config import options
27
27
from ...core .custom_log import redirect_custom_log
28
- from ...core import ENTITY_TYPE , OutputType
28
+ from ...core import ENTITY_TYPE , OutputType , recursive_tile
29
29
from ...core .context import get_context
30
30
from ...core .operand import OperandStage
31
31
from ...serialization .serializables import (
64
64
65
65
_support_get_group_without_as_index = pd_release_version [:2 ] > (1 , 0 )
66
66
67
+ _FUNCS_PREFER_SHUFFLE = {"nunique" }
68
+
67
69
68
70
class SizeRecorder :
69
71
def __init__ (self ):
@@ -163,6 +165,8 @@ class DataFrameGroupByAgg(DataFrameOperand, DataFrameOperandMixin):
163
165
method = StringField ("method" )
164
166
use_inf_as_na = BoolField ("use_inf_as_na" )
165
167
168
+ map_on_shuffle = AnyField ("map_on_shuffle" )
169
+
166
170
# for chunk
167
171
combine_size = Int32Field ("combine_size" )
168
172
chunk_store_limit = Int64Field ("chunk_store_limit" )
@@ -421,10 +425,29 @@ def _tile_with_shuffle(
421
425
in_df : TileableType ,
422
426
out_df : TileableType ,
423
427
func_infos : ReductionSteps ,
428
+ agg_chunks : List [ChunkType ] = None ,
424
429
):
425
- # First, perform groupby and aggregation on each chunk.
426
- agg_chunks = cls ._gen_map_chunks (op , in_df .chunks , out_df , func_infos )
427
- return cls ._perform_shuffle (op , agg_chunks , in_df , out_df , func_infos )
430
+ if op .map_on_shuffle is None :
431
+ op .map_on_shuffle = all (
432
+ agg_fun .custom_reduction is None for agg_fun in func_infos .agg_funcs
433
+ )
434
+
435
+ if not op .map_on_shuffle :
436
+ groupby_params = op .groupby_params .copy ()
437
+ selection = groupby_params .pop ("selection" , None )
438
+ groupby = in_df .groupby (** groupby_params )
439
+ if selection :
440
+ groupby = groupby [selection ]
441
+ result = groupby .transform (
442
+ op .raw_func , _call_agg = True , index = out_df .index_value
443
+ )
444
+ return (yield from recursive_tile (result ))
445
+ else :
446
+ # First, perform groupby and aggregation on each chunk.
447
+ agg_chunks = agg_chunks or cls ._gen_map_chunks (
448
+ op , in_df .chunks , out_df , func_infos
449
+ )
450
+ return cls ._perform_shuffle (op , agg_chunks , in_df , out_df , func_infos )
428
451
429
452
@classmethod
430
453
def _perform_shuffle (
@@ -624,8 +647,10 @@ def _tile_auto(
624
647
else :
625
648
# otherwise, use shuffle
626
649
logger .debug ("Choose shuffle method for groupby operand %s" , op )
627
- return cls ._perform_shuffle (
628
- op , chunks + left_chunks , in_df , out_df , func_infos
650
+ return (
651
+ yield from cls ._tile_with_shuffle (
652
+ op , in_df , out_df , func_infos , chunks + left_chunks
653
+ )
629
654
)
630
655
631
656
@classmethod
@@ -638,12 +663,16 @@ def tile(cls, op: "DataFrameGroupByAgg"):
638
663
func_infos = cls ._compile_funcs (op , in_df )
639
664
640
665
if op .method == "auto" :
641
- if len (in_df .chunks ) <= op .combine_size :
666
+ if set (op .func ) & _FUNCS_PREFER_SHUFFLE :
667
+ return (
668
+ yield from cls ._tile_with_shuffle (op , in_df , out_df , func_infos )
669
+ )
670
+ elif len (in_df .chunks ) <= op .combine_size :
642
671
return cls ._tile_with_tree (op , in_df , out_df , func_infos )
643
672
else :
644
673
return (yield from cls ._tile_auto (op , in_df , out_df , func_infos ))
645
674
if op .method == "shuffle" :
646
- return cls ._tile_with_shuffle (op , in_df , out_df , func_infos )
675
+ return ( yield from cls ._tile_with_shuffle (op , in_df , out_df , func_infos ) )
647
676
elif op .method == "tree" :
648
677
return cls ._tile_with_tree (op , in_df , out_df , func_infos )
649
678
else : # pragma: no cover
@@ -1075,7 +1104,15 @@ def execute(cls, ctx, op: "DataFrameGroupByAgg"):
1075
1104
pd .reset_option ("mode.use_inf_as_na" )
1076
1105
1077
1106
1078
- def agg (groupby , func = None , method = "auto" , combine_size = None , * args , ** kwargs ):
1107
+ def agg (
1108
+ groupby ,
1109
+ func = None ,
1110
+ method = "auto" ,
1111
+ combine_size = None ,
1112
+ map_on_shuffle = None ,
1113
+ * args ,
1114
+ ** kwargs ,
1115
+ ):
1079
1116
"""
1080
1117
Aggregate using one or more operations on grouped data.
1081
1118
@@ -1091,7 +1128,11 @@ def agg(groupby, func=None, method="auto", combine_size=None, *args, **kwargs):
1091
1128
in distributed mode and use 'tree' in local mode.
1092
1129
combine_size : int
1093
1130
The number of chunks to combine when method is 'tree'
1094
-
1131
+ map_on_shuffle : bool
1132
+ When not specified, will decide whether to perform aggregation on the
1133
+ map stage of shuffle (currently no aggregation when there is custom
1134
+ reduction in functions). Otherwise, whether to call map on map stage
1135
+ of shuffle is determined by the value.
1095
1136
1096
1137
Returns
1097
1138
-------
@@ -1138,5 +1179,6 @@ def agg(groupby, func=None, method="auto", combine_size=None, *args, **kwargs):
1138
1179
combine_size = combine_size or options .combine_size ,
1139
1180
chunk_store_limit = options .chunk_store_limit ,
1140
1181
use_inf_as_na = use_inf_as_na ,
1182
+ map_on_shuffle = map_on_shuffle ,
1141
1183
)
1142
1184
return agg_op (groupby )
0 commit comments