diff --git a/bindings/experimental/distrdf/python/DistRDF/HeadNode.py b/bindings/experimental/distrdf/python/DistRDF/HeadNode.py index b89d9d721a33f..ab7cc9cdde7fb 100644 --- a/bindings/experimental/distrdf/python/DistRDF/HeadNode.py +++ b/bindings/experimental/distrdf/python/DistRDF/HeadNode.py @@ -265,9 +265,8 @@ def execute_graph(self) -> None: def get_headnode(backend: BaseBackend, npartitions: int, *args) -> HeadNode: """ A factory for different kinds of head nodes of the RDataFrame computation - graph, depending on the arguments to the RDataFrame constructor. Currently - can return a TreeHeadNode or an EmptySourceHeadNode. Parses the arguments and - compares them against the possible RDataFrame constructors. + graph, depending on the arguments to the RDataFrame constructor. Parses the + arguments and compares them against the possible RDataFrame constructors. """ # Early check that arguments are accepted by RDataFrame @@ -276,60 +275,23 @@ def get_headnode(backend: BaseBackend, npartitions: int, *args) -> HeadNode: except TypeError: raise TypeError(("The arguments provided are not accepted by any RDataFrame constructor. " "See the RDataFrame documentation for the accepted constructor argument types.")) - - firstarg = args[0] - if isinstance(firstarg, int): - # RDataFrame(ULong64_t numEntries) - return EmptySourceHeadNode(backend, npartitions, localdf, firstarg) - elif isinstance(firstarg, (ROOT.TTree)): - # # RDataFrame(std::string_view treeName, filenameglob, defaultBranches = {}) - # # RDataFrame(std::string_view treename, filenames, defaultBranches = {}) - # # RDataFrame(std::string_view treeName, dirPtr, defaultBranches = {}) - # # RDataFrame(TTree &tree, const ColumnNames_t &defaultBranches = {}) - return TreeHeadNode(backend, npartitions, localdf, *args) - elif isinstance(firstarg, str): - # if first argument is a string, we want to check if the second argument - # which is a ROOT file holds a TTree or RNTuple - secondarg = args[1] - - if (isinstance (secondarg, (str, ROOT.std.string, ROOT.std.string_view))): - wildcards = ["[", "]", "*", "?"] - if (any(wildcard in secondarg for wildcard in wildcards)): - path = ROOT.Internal.TreeUtils.ExpandGlob(secondarg)[0] - else: - path = secondarg - else: - path = secondarg[0] - - with ROOT.TDirectory.TContext(), ROOT.TFile.Open(path, "READ_WITHOUT_GLOBALREGISTRATION") as f: - dataset = f.Get(firstarg) - try: - if isinstance(dataset, ROOT.TTree): - return TreeHeadNode(backend, npartitions, localdf, *args) - - elif isinstance(dataset, ROOT.Experimental.RNTuple): - return RNTupleHeadNode(backend, npartitions, localdf, *args) - - else: - raise RuntimeError("") - finally: - # Remove the reference to the object taken from the TFile - # This helps avoiding conflicts between the Python garbage - # collector, the cppyy memory regulator and the C++ object. - # TODO: Rewrite this whole section as a C++ function. Idea: - # the function should be a method of RInterfaceBase so that - # from the localdf object we could get a data source label and - # dispatch to the correct HeadNode type without reopening. - del dataset - elif isinstance(firstarg, ROOT.RDF.Experimental.RDatasetSpec): - # RDataFrame(rdatasetspec) - return RDatasetSpecHeadNode(backend, npartitions, localdf, *args) + + if isinstance(args[0], ROOT.RDF.Experimental.RDatasetSpec): + return RDatasetSpecHeadNode(backend, npartitions, localdf, *args) + + label = ROOT.Internal.RDF.GetDataSourceLabel(ROOT.RDF.AsRNode(localdf)) + if label == "TTreeDS": + return TreeHeadNode(backend, npartitions, localdf, *args) + elif label == "RNTupleDS": + return RNTupleHeadNode(backend, npartitions, localdf, *args) + elif label == "EmptyDS": + return EmptySourceHeadNode(backend, npartitions, localdf, args[0]) else: raise RuntimeError( - ("First argument {} of type {} is not recognised as a supported " - "argument for distributed RDataFrame. Currently only TTree/Tchain " - "based datasets or datasets created from a number of entries " - "can be processed distributedly.").format(firstarg, type(firstarg))) + (f"First argument {args[0]} of type {type(args[0])} is not " + "recognised as a supported argument for distributed RDataFrame. " + "Currently supported data sources are: TTree, RNTuple or an empty " + "data source.")) class EmptySourceHeadNode(HeadNode):