Skip to content

Commit 9cf17b9

Browse files
committed
Improve match subject inference
1 parent 50fc847 commit 9cf17b9

File tree

4 files changed

+172
-28
lines changed

4 files changed

+172
-28
lines changed

mypy/checker.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5567,7 +5567,9 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
55675567
# capture variable may depend on multiple patterns (it
55685568
# will be a union of all capture types). This pass ignores
55695569
# guard expressions.
5570-
pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns]
5570+
pattern_types = [
5571+
self.pattern_checker.accept(p, subject_type, [s.subject]) for p in s.patterns
5572+
]
55715573
type_maps: list[TypeMap] = [t.captures for t in pattern_types]
55725574
inferred_types = self.infer_variable_types_from_type_maps(type_maps)
55735575

@@ -5577,7 +5579,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
55775579
current_subject_type = self.expr_checker.narrow_type_from_binder(
55785580
named_subject, subject_type
55795581
)
5580-
pattern_type = self.pattern_checker.accept(p, current_subject_type)
5582+
pattern_type = self.pattern_checker.accept(p, current_subject_type, [s.subject])
55815583
with self.binder.frame_context(can_skip=True, fall_through=2):
55825584
if b.is_unreachable or isinstance(
55835585
get_proper_type(pattern_type.type), UninhabitedType

mypy/checkpattern.py

Lines changed: 84 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,23 @@
1010
from mypy.checkmember import analyze_member_access
1111
from mypy.expandtype import expand_type_by_instance
1212
from mypy.join import join_types
13-
from mypy.literals import literal_hash
13+
from mypy.literals import Key, literal_hash
1414
from mypy.maptype import map_instance_to_supertype
1515
from mypy.meet import narrow_declared_type
1616
from mypy.messages import MessageBuilder
17-
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, TypeInfo, Var
17+
from mypy.nodes import (
18+
ARG_POS,
19+
Context,
20+
Expression,
21+
IndexExpr,
22+
IntExpr,
23+
MemberExpr,
24+
NameExpr,
25+
TypeAlias,
26+
TypeInfo,
27+
UnaryExpr,
28+
Var,
29+
)
1830
from mypy.options import Options
1931
from mypy.patterns import (
2032
AsPattern,
@@ -96,10 +108,8 @@ class PatternChecker(PatternVisitor[PatternType]):
96108
msg: MessageBuilder
97109
# Currently unused
98110
plugin: Plugin
99-
# The expression being matched against the pattern
100-
subject: Expression
101-
102-
subject_type: Type
111+
# The expressions being matched against the (sub)pattern
112+
subject_context: list[list[Expression]]
103113
# Type of the subject to check the (sub)pattern against
104114
type_context: list[Type]
105115
# Types that match against self instead of their __match_args__ if used as a class pattern
@@ -118,24 +128,28 @@ def __init__(
118128
self.msg = msg
119129
self.plugin = plugin
120130

131+
self.subject_context = []
121132
self.type_context = []
122133
self.self_match_types = self.generate_types_from_names(self_match_type_names)
123134
self.non_sequence_match_types = self.generate_types_from_names(
124135
non_sequence_match_type_names
125136
)
126137
self.options = options
127138

128-
def accept(self, o: Pattern, type_context: Type) -> PatternType:
139+
def accept(self, o: Pattern, type_context: Type, subject: list[Expression]) -> PatternType:
140+
self.subject_context.append(subject)
129141
self.type_context.append(type_context)
130142
result = o.accept(self)
143+
self.subject_context.pop()
131144
self.type_context.pop()
132145

133146
return result
134147

135148
def visit_as_pattern(self, o: AsPattern) -> PatternType:
149+
current_subject = self.subject_context[-1]
136150
current_type = self.type_context[-1]
137151
if o.pattern is not None:
138-
pattern_type = self.accept(o.pattern, current_type)
152+
pattern_type = self.accept(o.pattern, current_type, current_subject)
139153
typ, rest_type, type_map = pattern_type
140154
else:
141155
typ, rest_type, type_map = current_type, UninhabitedType(), {}
@@ -150,14 +164,15 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType:
150164
return PatternType(typ, rest_type, type_map)
151165

152166
def visit_or_pattern(self, o: OrPattern) -> PatternType:
167+
current_subject = self.subject_context[-1]
153168
current_type = self.type_context[-1]
154169

155170
#
156171
# Check all the subpatterns
157172
#
158-
pattern_types = []
173+
pattern_types: list[PatternType] = []
159174
for pattern in o.patterns:
160-
pattern_type = self.accept(pattern, current_type)
175+
pattern_type = self.accept(pattern, current_type, current_subject)
161176
pattern_types.append(pattern_type)
162177
if not is_uninhabited(pattern_type.type):
163178
current_type = pattern_type.rest_type
@@ -173,28 +188,40 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
173188
#
174189
# Check the capture types
175190
#
176-
capture_types: dict[Var, list[tuple[Expression, Type]]] = defaultdict(list)
191+
capture_types: dict[Var, dict[Key | None, list[tuple[Expression, Type]]]] = defaultdict(
192+
lambda: defaultdict(list)
193+
)
194+
capture_expr_keys: set[Key | None] = set()
177195
# Collect captures from the first subpattern
178196
for expr, typ in pattern_types[0].captures.items():
179197
node = get_var(expr)
180-
capture_types[node].append((expr, typ))
198+
key = literal_hash(expr)
199+
capture_types[node][key].append((expr, typ))
200+
if isinstance(expr, NameExpr):
201+
capture_expr_keys.add(key)
181202

182203
# Check if other subpatterns capture the same names
183204
for i, pattern_type in enumerate(pattern_types[1:]):
184-
vars = {get_var(expr) for expr, _ in pattern_type.captures.items()}
185-
if capture_types.keys() != vars:
205+
vars = {
206+
literal_hash(expr) for expr in pattern_type.captures if isinstance(expr, NameExpr)
207+
}
208+
if capture_expr_keys != vars:
209+
# Only fail for directly captured names (with NameExpr)
186210
self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, o.patterns[i])
187211
for expr, typ in pattern_type.captures.items():
188212
node = get_var(expr)
189-
capture_types[node].append((expr, typ))
213+
key = literal_hash(expr)
214+
capture_types[node][key].append((expr, typ))
190215

191216
captures: dict[Expression, Type] = {}
192-
for capture_list in capture_types.values():
193-
typ = UninhabitedType()
194-
for _, other in capture_list:
195-
typ = make_simplified_union([typ, other])
217+
for expressions in capture_types.values():
218+
for key, capture_list in expressions.items():
219+
if other_types := [entry[1] for entry in capture_list]:
220+
typ = make_simplified_union(other_types)
221+
else:
222+
typ = UninhabitedType()
196223

197-
captures[capture_list[0][0]] = typ
224+
captures[capture_list[0][0]] = typ
198225

199226
union_type = make_simplified_union(types)
200227
return PatternType(union_type, current_type, captures)
@@ -284,12 +311,24 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
284311
contracted_inner_types = self.contract_starred_pattern_types(
285312
inner_types, star_position, required_patterns
286313
)
287-
for p, t in zip(o.patterns, contracted_inner_types):
288-
pattern_type = self.accept(p, t)
314+
current_subjects: list[list[Expression]] = [[] for _ in range(len(contracted_inner_types))]
315+
for s in self.subject_context[-1]:
316+
# Support x[0], x[1], ... lookup for every position before the star pattern
317+
end_pos = len(contracted_inner_types) if star_position is None else star_position
318+
for i in range(end_pos):
319+
current_subjects[i].append(IndexExpr(s, IntExpr(i)))
320+
# For everything after the star pattern use negative indexes: ..., x[-2], x[-1]
321+
for i in range((star_position or -1) + 1, len(contracted_inner_types)):
322+
offset = len(contracted_inner_types) - i
323+
current_subjects[i].append(IndexExpr(s, UnaryExpr("-", IntExpr(offset))))
324+
for p, t, s in zip(o.patterns, contracted_inner_types, current_subjects):
325+
pattern_type = self.accept(p, t, s)
289326
typ, rest, type_map = pattern_type
290327
contracted_new_inner_types.append(typ)
291328
contracted_rest_inner_types.append(rest)
292329
self.update_type_map(captures, type_map)
330+
if s:
331+
self.update_type_map(captures, {subject: typ for subject in s})
293332

294333
new_inner_types = self.expand_starred_pattern_types(
295334
contracted_new_inner_types, star_position, len(inner_types), unpack_index is not None
@@ -473,11 +512,18 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
473512
if inner_type is None:
474513
can_match = False
475514
inner_type = self.chk.named_type("builtins.object")
476-
pattern_type = self.accept(value, inner_type)
515+
current_subjects: list[Expression] = [
516+
IndexExpr(s, key) for s in self.subject_context[-1]
517+
]
518+
pattern_type = self.accept(value, inner_type, current_subjects)
477519
if is_uninhabited(pattern_type.type):
478520
can_match = False
479521
else:
480522
self.update_type_map(captures, pattern_type.captures)
523+
if current_subjects:
524+
self.update_type_map(
525+
captures, {subject: pattern_type.type for subject in current_subjects}
526+
)
481527

482528
if o.rest is not None:
483529
mapping = self.chk.named_type("typing.Mapping")
@@ -581,7 +627,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
581627
if self.should_self_match(typ):
582628
if len(o.positionals) > 1:
583629
self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
584-
pattern_type = self.accept(o.positionals[0], narrowed_type)
630+
pattern_type = self.accept(o.positionals[0], narrowed_type, [])
585631
if not is_uninhabited(pattern_type.type):
586632
return PatternType(
587633
pattern_type.type,
@@ -681,11 +727,20 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
681727
elif keyword is not None:
682728
new_type = self.chk.add_any_attribute_to_type(new_type, keyword)
683729

684-
inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type)
730+
current_subjects: list[Expression] = []
731+
if keyword is not None:
732+
current_subjects = [MemberExpr(s, keyword) for s in self.subject_context[-1]]
733+
inner_type, inner_rest_type, inner_captures = self.accept(
734+
pattern, key_type, current_subjects
735+
)
685736
if is_uninhabited(inner_type):
686737
can_match = False
687738
else:
688739
self.update_type_map(captures, inner_captures)
740+
if current_subjects:
741+
self.update_type_map(
742+
captures, {subject: inner_type for subject in current_subjects}
743+
)
689744
if not is_uninhabited(inner_rest_type):
690745
rest_type = current_type
691746

@@ -799,6 +854,10 @@ def get_var(expr: Expression) -> Var:
799854
Warning: this is only true for expressions captured by a match statement.
800855
Don't call it from anywhere else
801856
"""
857+
if isinstance(expr, MemberExpr):
858+
return get_var(expr.expr)
859+
if isinstance(expr, IndexExpr):
860+
return get_var(expr.base)
802861
assert isinstance(expr, NameExpr), expr
803862
node = expr.node
804863
assert isinstance(node, Var), node

mypy/literals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def visit_set_expr(self, e: SetExpr) -> Key | None:
228228
return self.seq_expr(e, "Set")
229229

230230
def visit_index_expr(self, e: IndexExpr) -> Key | None:
231-
if literal(e.index) == LITERAL_YES:
231+
if literal(e.index) != LITERAL_NO:
232232
return ("Index", literal_hash(e.base), literal_hash(e.index))
233233
return None
234234

test-data/unit/check-python310.test

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2965,3 +2965,86 @@ val: int = 8
29652965
match val:
29662966
case FOO: # E: Cannot assign to final name "FOO"
29672967
pass
2968+
2969+
[case testMatchSubjectInferenceSequence]
2970+
m: object
2971+
2972+
match m:
2973+
case [1, True]:
2974+
reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.int]"
2975+
reveal_type(m[0]) # N: Revealed type is "Literal[1]"
2976+
reveal_type(m[-2]) # N: Revealed type is "Literal[1]"
2977+
reveal_type(m[1]) # N: Revealed type is "Literal[True]"
2978+
reveal_type(m[-1]) # N: Revealed type is "Literal[True]"
2979+
case [1, *_, False]:
2980+
reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.object]"
2981+
reveal_type(m[-1]) # N: Revealed type is "Literal[False]"
2982+
case [[1], [True]]:
2983+
reveal_type(m[0][0]) # N: Revealed type is "Literal[1]"
2984+
reveal_type(m[-2][0]) # N: Revealed type is "Literal[1]"
2985+
reveal_type(m[1][0]) # N: Revealed type is "Literal[True]"
2986+
reveal_type(m[-1][0]) # N: Revealed type is "Literal[True]"
2987+
[builtins fixtures/tuple.pyi]
2988+
2989+
[case testMatchSubjectInferenceMapping]
2990+
from typing import Any
2991+
m: Any
2992+
2993+
match m:
2994+
case {"key": 1}:
2995+
reveal_type(m["key"]) # N: Revealed type is "Literal[1]"
2996+
2997+
[case testMatchSubjectInferenceClass]
2998+
from typing import Final
2999+
3000+
class A:
3001+
__match_args__: Final = ("a", "b")
3002+
a: str | None
3003+
b: int | None
3004+
3005+
m: A
3006+
3007+
match m:
3008+
case A("Hello", 2):
3009+
reveal_type(m.a) # N: Revealed type is "Literal['Hello']"
3010+
reveal_type(m.b) # N: Revealed type is "Literal[2]"
3011+
case A(a="Hello", b=2):
3012+
reveal_type(m.a) # N: Revealed type is "Literal['Hello']"
3013+
reveal_type(m.b) # N: Revealed type is "Literal[2]"
3014+
case A(a=str()) | A(a=None):
3015+
reveal_type(m.a) # N: Revealed type is "Union[builtins.str, None]"
3016+
case object(some_attr=str()):
3017+
reveal_type(m.some_attr) # N: Revealed type is "builtins.str"
3018+
[builtins fixtures/tuple.pyi]
3019+
3020+
[case testMatchSubjectInferenceOR]
3021+
m: object
3022+
3023+
match m:
3024+
case [1, 2, 3] | [8, 9]:
3025+
reveal_type(m[0]) # N: Revealed type is "Union[Literal[1], Literal[8]]"
3026+
reveal_type(m[1]) # N: Revealed type is "Union[Literal[2], Literal[9]]"
3027+
reveal_type(m[2]) # N: Revealed type is "Literal[3]"
3028+
3029+
[case testMatchSubjectNested]
3030+
from typing import Any
3031+
class A:
3032+
a: str | None
3033+
b: int | None
3034+
3035+
m: Any
3036+
3037+
match m:
3038+
case {"key": [0, A(a="Hello")]}:
3039+
reveal_type(m) # N: Revealed type is "Any"
3040+
reveal_type(m["key"]) # N: Revealed type is "Any"
3041+
reveal_type(m["key"][0]) # N: Revealed type is "Literal[0]"
3042+
reveal_type(m["key"][1]) # N: Revealed type is "__main__.A"
3043+
reveal_type(m["key"][1].a) # N: Revealed type is "Literal['Hello']"
3044+
case [0, {"key": 2}]:
3045+
reveal_type(m[1]) # N: Revealed type is "Any"
3046+
reveal_type(m[1]["key"]) # N: Revealed type is "Literal[2]"
3047+
case object(a=[A(a="Hello") | A(a="World")]):
3048+
reveal_type(m.a) # N: Revealed type is "Any"
3049+
reveal_type(m.a[0]) # N: Revealed type is "__main__.A"
3050+
reveal_type(m.a[0].a) # N: Revealed type is "Union[Literal['Hello'], Literal['World']]"

0 commit comments

Comments
 (0)