Skip to content

Commit 3166ed7

Browse files
authored
Add iterator(..., kind) to TypedTree (#18)
* Update take_the_tour.ipynb * Add TypedTree.iterator(..., kind)` and `TypedNode.iterator(..., kind)` * Add count_descendants(kind) * Update ruff to 0.9 * More tests
1 parent 1efa41a commit 3166ed7

12 files changed

+586
-486
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
repos:
22
- repo: https://github.com/astral-sh/ruff-pre-commit
33
# Ruff version.
4-
rev: v0.6.6
4+
rev: v0.9.5
55
hooks:
66
# Run the linter.
77
- id: ruff

CHANGELOG.md

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
# Changelog
22

3-
## 1.0.1 (unreleased)
3+
## 1.1.0 (unreleased)
44

5+
- DEPRECATE: `TypedTree.iter_by_type()`. Use `iterator(.., kind)`instead.
6+
- New methods `TypedTree.iterator(..., kind=ANY_KIND)`,
7+
`TypedNode.iterator(..., kind=ANY_KIND)`,
8+
and `TypedTree.count_descendants(leaves_only=False, kind=ANY_KIND)`
9+
510
## 1.0.0 (2024-12-27)
611
- Add benchmarks (using [Benchman](https://github.com/mar10/benchman)).
712
- Drop support for Python 3.8

Pipfile

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ pytest = "*"
1717
pytest-cov = "*"
1818
PyYAML = "*"
1919
rdflib = "*"
20-
ruff = "*"
20+
ruff = "~=0.9"
2121
setuptools = ">=42.0"
2222
Sphinx = "*"
2323
sphinx_rtd_theme = "*"

Pipfile.lock

+450-438
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/jupyter/take_the_tour.ipynb

+42-23
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
},
2323
{
2424
"cell_type": "code",
25-
"execution_count": 26,
25+
"execution_count": 1,
2626
"metadata": {},
2727
"outputs": [
2828
{
@@ -68,7 +68,7 @@
6868
},
6969
{
7070
"cell_type": "code",
71-
"execution_count": 27,
71+
"execution_count": 2,
7272
"metadata": {},
7373
"outputs": [],
7474
"source": [
@@ -103,7 +103,7 @@
103103
},
104104
{
105105
"cell_type": "code",
106-
"execution_count": 28,
106+
"execution_count": 3,
107107
"metadata": {},
108108
"outputs": [],
109109
"source": [
@@ -132,21 +132,21 @@
132132
},
133133
{
134134
"cell_type": "code",
135-
"execution_count": 29,
135+
"execution_count": 4,
136136
"metadata": {},
137137
"outputs": [
138138
{
139139
"name": "stdout",
140140
"output_type": "stream",
141141
"text": [
142142
"Tree<'Organization'>\n",
143-
"├── <__main__.Department object at 0x118c3fbc0>\n",
144-
"│ ├── <__main__.Department object at 0x118646c30>\n",
145-
"│ │ ╰── <__main__.Person object at 0x118c68170>\n",
146-
"│ ╰── <__main__.Person object at 0x118c3f770>\n",
147-
"├── <__main__.Department object at 0x118c3ffe0>\n",
148-
"│ ╰── <__main__.Person object at 0x118c68bf0>\n",
149-
"╰── <__main__.Person object at 0x118c3ffb0>\n"
143+
"├── <__main__.Department object at 0x106a66e40>\n",
144+
"│ ├── <__main__.Department object at 0x106a65610>\n",
145+
"│ │ ╰── <__main__.Person object at 0x10c02e4e0>\n",
146+
"│ ╰── <__main__.Person object at 0x10c02ede0>\n",
147+
"├── <__main__.Department object at 0x106a64170>\n",
148+
"│ ╰── <__main__.Person object at 0x10c02e5a0>\n",
149+
"╰── <__main__.Person object at 0x10c02e150>\n"
150150
]
151151
}
152152
],
@@ -183,7 +183,7 @@
183183
},
184184
{
185185
"cell_type": "code",
186-
"execution_count": 30,
186+
"execution_count": 5,
187187
"metadata": {},
188188
"outputs": [
189189
{
@@ -215,16 +215,16 @@
215215
},
216216
{
217217
"cell_type": "code",
218-
"execution_count": 31,
218+
"execution_count": 6,
219219
"metadata": {},
220220
"outputs": [
221221
{
222222
"data": {
223223
"text/plain": [
224-
"Node<'Person<Alice (25)>', data_id=294404091>"
224+
"Node<'Person<Alice (25)>', data_id=281030165>"
225225
]
226226
},
227-
"execution_count": 31,
227+
"execution_count": 6,
228228
"metadata": {},
229229
"output_type": "execute_result"
230230
}
@@ -243,16 +243,16 @@
243243
},
244244
{
245245
"cell_type": "code",
246-
"execution_count": 32,
246+
"execution_count": 7,
247247
"metadata": {},
248248
"outputs": [
249249
{
250250
"data": {
251251
"text/plain": [
252-
"<__main__.Person at 0x118c3ffb0>"
252+
"<__main__.Person at 0x10c02e150>"
253253
]
254254
},
255-
"execution_count": 32,
255+
"execution_count": 7,
256256
"metadata": {},
257257
"output_type": "execute_result"
258258
}
@@ -276,23 +276,42 @@
276276
},
277277
{
278278
"cell_type": "code",
279-
"execution_count": 33,
279+
"execution_count": 11,
280280
"metadata": {},
281281
"outputs": [
282282
{
283283
"data": {
284284
"text/plain": [
285-
"[Node<'Department<Development>', data_id=294404028>,\n",
286-
" Node<'Department<Marketing>', data_id=294404094>]"
285+
"'Department<Test>'"
287286
]
288287
},
289-
"execution_count": 33,
288+
"execution_count": 11,
289+
"metadata": {},
290+
"output_type": "execute_result"
291+
}
292+
],
293+
"source": [
294+
"str(tree[claire].parent.data)"
295+
]
296+
},
297+
{
298+
"cell_type": "code",
299+
"execution_count": 12,
300+
"metadata": {},
301+
"outputs": [
302+
{
303+
"data": {
304+
"text/plain": [
305+
"[Node<'Department<Development>', data_id=275408612>,\n",
306+
" Node<'Department<Marketing>', data_id=275407895>]"
307+
]
308+
},
309+
"execution_count": 12,
290310
"metadata": {},
291311
"output_type": "execute_result"
292312
}
293313
],
294314
"source": [
295-
"# tree[alice].parent.data\n",
296315
"tree[alice].get_siblings()"
297316
]
298317
},

docs/sphinx/ug_graphs.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ Navigation methods are type-aware now::
356356
assert cause1.get_index() == 0
357357
assert cause1.get_index(any_kind=True) == 2
358358

359-
assert len(list(tree.iter_by_type("effect"))) == 3
359+
assert len(list(tree.iterator(kind="effect"))) == 3
360360

361361
Keep in mind that a tree node is unique within a tree, but may reference identical
362362
data objects, so these `clones` could exist at different locations of tree.

nutree/tree.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1022,9 +1022,9 @@ def _self_check(self) -> Literal[True]:
10221022
# assert node._data_id == self.calc_data_id(node.data), node
10231023
assert node._data_id in self._nodes_by_data_id, node
10241024
assert node._node_id == id(node), f"{node}: {node._node_id} != {id(node)}"
1025-
assert (
1026-
node._children is None or len(node._children) > 0
1027-
), f"{node}: {node._children}"
1025+
assert node._children is None or len(node._children) > 0, (
1026+
f"{node}: {node._children}"
1027+
)
10281028

10291029
assert len(self._node_by_id) == len(node_list)
10301030

nutree/tree_generator.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ class Randomizer(ABC):
4343
"""
4444

4545
def __init__(self, *, probability: float = 1.0) -> None:
46-
assert (
47-
isinstance(probability, float) and 0.0 <= probability <= 1.0
48-
), f"probality must be in the range [0.0 .. 1.0]: {probability}"
46+
assert isinstance(probability, float) and 0.0 <= probability <= 1.0, (
47+
f"probality must be in the range [0.0 .. 1.0]: {probability}"
48+
)
4949
self.probability = probability
5050

5151
def _skip_value(self) -> bool:
@@ -84,9 +84,9 @@ def __init__(
8484
none_value: Any = None,
8585
) -> None:
8686
super().__init__(probability=probability)
87-
assert type(min_val) is type(
88-
max_val
89-
), f"min_val and max_val must be of the same type: {min_val}, {max_val}"
87+
assert type(min_val) is type(max_val), (
88+
f"min_val and max_val must be of the same type: {min_val}, {max_val}"
89+
)
9090
self.is_float = isinstance(min_val, float)
9191
self.min = min_val
9292
self.max = max_val
@@ -129,19 +129,19 @@ def __init__(
129129
) -> None:
130130
super().__init__(probability=probability)
131131
assert isinstance(min_dt, date), f"min_dt must be a date: {min_dt}"
132-
assert isinstance(
133-
max_dt, (date, int)
134-
), f"max_dt must be a date or int: {max_dt}"
132+
assert isinstance(max_dt, (date, int)), (
133+
f"max_dt must be a date or int: {max_dt}"
134+
)
135135

136136
if isinstance(max_dt, int):
137137
self.delta_days = max_dt
138138
max_dt = min_dt + timedelta(days=self.delta_days)
139139
else:
140140
self.delta_days = (max_dt - min_dt).days
141141

142-
assert (
143-
max_dt > min_dt
144-
), f"max_dt must be greater than min_dt: {min_dt}, {max_dt}"
142+
assert max_dt > min_dt, (
143+
f"max_dt must be greater than min_dt: {min_dt}, {max_dt}"
144+
)
145145

146146
self.min = min_dt
147147
self.max = max_dt

nutree/typed_tree.py

+52-2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ROOT_NODE_ID,
2828
DataIdType,
2929
DeserializeMapperType,
30+
IterMethod,
3031
KeyMapType,
3132
MapperCallbackType,
3233
PredicateCallbackType,
@@ -137,12 +138,44 @@ def last_child(self, kind: str | type[ANY_KIND]) -> Self | None:
137138
return n
138139
return None
139140

141+
def iterator(
142+
self,
143+
method: IterMethod = IterMethod.PRE_ORDER,
144+
*,
145+
add_self=False,
146+
kind: str | type[ANY_KIND] = ANY_KIND,
147+
) -> Iterator[TypedNode[TData]]:
148+
"""Return an iterator that walks the tree in the specified order."""
149+
if kind is ANY_KIND:
150+
yield from super().iterator(method=method, add_self=add_self)
151+
return
152+
153+
if add_self and self.kind == kind:
154+
yield self
155+
for n in super().iterator(method=method, add_self=False):
156+
if n.kind == kind:
157+
yield n
158+
return
159+
140160
def has_children(self, kind: str | type[ANY_KIND]) -> bool:
141161
"""Return true if this node has one or more children."""
142162
if kind is ANY_KIND:
143163
return bool(self._children)
144164
return len(self.get_children(kind)) > 1
145165

166+
def count_descendants(
167+
self, *, leaves_only=False, kind: str | type[ANY_KIND] = ANY_KIND
168+
) -> int:
169+
"""Return number of descendant nodes, not counting self."""
170+
if kind is ANY_KIND:
171+
return super().count_descendants(leaves_only=leaves_only)
172+
all = not leaves_only
173+
i = 0
174+
for node in self.iterator():
175+
if (all or not node._children) and node.kind == kind:
176+
i += 1
177+
return i
178+
146179
def get_siblings(self, *, add_self=False, any_kind=False) -> list[Self]:
147180
"""Return a list of all sibling entries of self (excluding self) if any."""
148181
if any_kind:
@@ -630,13 +663,30 @@ def last_child(self, kind: str | type[ANY_KIND]) -> TypedNode[TData] | None:
630663
return self.system_root.last_child(kind=kind)
631664

632665
def iter_by_type(self, kind: str | type[ANY_KIND]) -> Iterator[TypedNode[TData]]:
666+
"""@deprecated: Use :meth:`iterator` with `kind` argument instead."""
667+
yield from self.iterator(kind=kind)
668+
669+
def iterator(
670+
self,
671+
method: IterMethod = IterMethod.PRE_ORDER,
672+
*,
673+
kind: str | type[ANY_KIND] = ANY_KIND,
674+
) -> Iterator[TypedNode[TData]]:
633675
if kind == ANY_KIND:
634-
yield from self.iterator()
635-
for n in self.iterator():
676+
yield from super().iterator(method=method)
677+
return
678+
679+
for n in super().iterator(method=method):
636680
if n._kind == kind:
637681
yield n
638682
return
639683

684+
def count_descendants(
685+
self, *, leaves_only=False, kind: str | type[ANY_KIND] = ANY_KIND
686+
) -> int:
687+
"""Return number of nodes, optionally restricted to type."""
688+
return self.system_root.count_descendants(leaves_only=leaves_only, kind=kind)
689+
640690
def save(
641691
self,
642692
target: IO[str] | str | Path,

tests/test_objects.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,9 @@ class FrozenItem:
285285

286286
assert isinstance(dict_node.data, FrozenItem)
287287
assert dict_node.data is item, "dataclass should be stored as reference"
288-
assert (
289-
dict_node.price == 12.34
290-
), "should support attribute access via forwardinging"
288+
assert dict_node.price == 12.34, (
289+
"should support attribute access via forwardinging"
290+
)
291291
with pytest.raises(AttributeError):
292292
_ = dict_node.foo
293293

tests/test_rdf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def test_typed_tree(self):
6464
assert cause1.get_index() == 0
6565
assert cause1.get_index(any_kind=True) == 2
6666

67-
assert len(list(tree.iter_by_type("effect"))) == 3
67+
assert len(list(tree.iterator(kind="effect"))) == 3
6868

6969
# tree.print()
7070
# print()

tests/test_typed_tree.py

+14
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,21 @@ def test_add_child(self):
9595

9696
assert cause2.parent is fail1
9797

98+
assert tree.count == 8
99+
assert tree.count_descendants() == 8
100+
tree.print()
101+
assert tree.count_descendants(leaves_only=True) == 6
102+
assert tree.count_descendants(kind="cause") == 2
103+
assert tree.count_descendants(leaves_only=True, kind="failure") == 1
104+
assert tree.system_root.count_descendants(kind="failure") == 2
105+
98106
assert len(list(tree.iter_by_type("cause"))) == 2
107+
assert len(list(tree.iterator(kind="cause"))) == 2
108+
assert len(list(tree.iterator())) == 8
109+
110+
assert len(list(tree.system_root.iterator(add_self=True))) == 9
111+
assert len(list(tree.system_root.iterator(kind="cause"))) == 2
112+
assert len(list(tree.system_root.iterator(kind="cause", add_self=True))) == 2
99113

100114
assert cause2.get_children("undefined") == []
101115

0 commit comments

Comments
 (0)