Skip to content

Commit 1da718a

Browse files
committed
Improve match subject inference
1 parent e852829 commit 1da718a

File tree

4 files changed

+261
-37
lines changed

4 files changed

+261
-37
lines changed

mypy/checker.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5567,7 +5567,10 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
55675567
# capture variable may depend on multiple patterns (it
55685568
# will be a union of all capture types). This pass ignores
55695569
# guard expressions.
5570-
pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns]
5570+
pattern_types = [
5571+
self.pattern_checker.accept(p, subject_type, [unwrapped_subject])
5572+
for p in s.patterns
5573+
]
55715574
type_maps: list[TypeMap] = [t.captures for t in pattern_types]
55725575
inferred_types = self.infer_variable_types_from_type_maps(type_maps)
55735576

@@ -5577,7 +5580,9 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
55775580
current_subject_type = self.expr_checker.narrow_type_from_binder(
55785581
named_subject, subject_type
55795582
)
5580-
pattern_type = self.pattern_checker.accept(p, current_subject_type)
5583+
pattern_type = self.pattern_checker.accept(
5584+
p, current_subject_type, [unwrapped_subject]
5585+
)
55815586
with self.binder.frame_context(can_skip=True, fall_through=2):
55825587
if b.is_unreachable or isinstance(
55835588
get_proper_type(pattern_type.type), UninhabitedType

mypy/checkpattern.py

Lines changed: 116 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,25 @@
1010
from mypy.checkmember import analyze_member_access
1111
from mypy.expandtype import expand_type_by_instance
1212
from mypy.join import join_types
13-
from mypy.literals import literal_hash
13+
from mypy.literals import Key, literal_hash
1414
from mypy.maptype import map_instance_to_supertype
1515
from mypy.meet import narrow_declared_type
1616
from mypy.messages import MessageBuilder
17-
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, TypeInfo, Var
17+
from mypy.nodes import (
18+
ARG_POS,
19+
Context,
20+
Expression,
21+
IndexExpr,
22+
IntExpr,
23+
ListExpr,
24+
MemberExpr,
25+
NameExpr,
26+
TupleExpr,
27+
TypeAlias,
28+
TypeInfo,
29+
UnaryExpr,
30+
Var,
31+
)
1832
from mypy.options import Options
1933
from mypy.patterns import (
2034
AsPattern,
@@ -96,10 +110,8 @@ class PatternChecker(PatternVisitor[PatternType]):
96110
msg: MessageBuilder
97111
# Currently unused
98112
plugin: Plugin
99-
# The expression being matched against the pattern
100-
subject: Expression
101-
102-
subject_type: Type
113+
# The expressions being matched against the (sub)pattern
114+
subject_context: list[list[Expression]]
103115
# Type of the subject to check the (sub)pattern against
104116
type_context: list[Type]
105117
# Types that match against self instead of their __match_args__ if used as a class pattern
@@ -118,24 +130,28 @@ def __init__(
118130
self.msg = msg
119131
self.plugin = plugin
120132

133+
self.subject_context = []
121134
self.type_context = []
122135
self.self_match_types = self.generate_types_from_names(self_match_type_names)
123136
self.non_sequence_match_types = self.generate_types_from_names(
124137
non_sequence_match_type_names
125138
)
126139
self.options = options
127140

128-
def accept(self, o: Pattern, type_context: Type) -> PatternType:
141+
def accept(self, o: Pattern, type_context: Type, subject: list[Expression]) -> PatternType:
142+
self.subject_context.append(subject)
129143
self.type_context.append(type_context)
130144
result = o.accept(self)
145+
self.subject_context.pop()
131146
self.type_context.pop()
132147

133148
return result
134149

135150
def visit_as_pattern(self, o: AsPattern) -> PatternType:
151+
current_subject = self.subject_context[-1]
136152
current_type = self.type_context[-1]
137153
if o.pattern is not None:
138-
pattern_type = self.accept(o.pattern, current_type)
154+
pattern_type = self.accept(o.pattern, current_type, current_subject)
139155
typ, rest_type, type_map = pattern_type
140156
else:
141157
typ, rest_type, type_map = current_type, UninhabitedType(), {}
@@ -150,14 +166,15 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType:
150166
return PatternType(typ, rest_type, type_map)
151167

152168
def visit_or_pattern(self, o: OrPattern) -> PatternType:
169+
current_subject = self.subject_context[-1]
153170
current_type = self.type_context[-1]
154171

155172
#
156173
# Check all the subpatterns
157174
#
158-
pattern_types = []
175+
pattern_types: list[PatternType] = []
159176
for pattern in o.patterns:
160-
pattern_type = self.accept(pattern, current_type)
177+
pattern_type = self.accept(pattern, current_type, current_subject)
161178
pattern_types.append(pattern_type)
162179
if not is_uninhabited(pattern_type.type):
163180
current_type = pattern_type.rest_type
@@ -173,28 +190,42 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
173190
#
174191
# Check the capture types
175192
#
176-
capture_types: dict[Var, list[tuple[Expression, Type]]] = defaultdict(list)
193+
capture_types: dict[Var, dict[Key | None, list[tuple[Expression, Type]]]] = defaultdict(
194+
lambda: defaultdict(list)
195+
)
196+
capture_expr_keys: set[Key | None] = set()
177197
# Collect captures from the first subpattern
178198
for expr, typ in pattern_types[0].captures.items():
179-
node = get_var(expr)
180-
capture_types[node].append((expr, typ))
199+
if (node := get_var(expr)) is None:
200+
continue
201+
key = literal_hash(expr)
202+
capture_types[node][key].append((expr, typ))
203+
if isinstance(expr, NameExpr):
204+
capture_expr_keys.add(key)
181205

182206
# Check if other subpatterns capture the same names
183207
for i, pattern_type in enumerate(pattern_types[1:]):
184-
vars = {get_var(expr) for expr, _ in pattern_type.captures.items()}
185-
if capture_types.keys() != vars:
208+
vars = {
209+
literal_hash(expr) for expr in pattern_type.captures if isinstance(expr, NameExpr)
210+
}
211+
if capture_expr_keys != vars:
212+
# Only fail for directly captured names (with NameExpr)
186213
self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, o.patterns[i])
187214
for expr, typ in pattern_type.captures.items():
188-
node = get_var(expr)
189-
capture_types[node].append((expr, typ))
215+
if (node := get_var(expr)) is None:
216+
continue
217+
key = literal_hash(expr)
218+
capture_types[node][key].append((expr, typ))
190219

191220
captures: dict[Expression, Type] = {}
192-
for capture_list in capture_types.values():
193-
typ = UninhabitedType()
194-
for _, other in capture_list:
195-
typ = make_simplified_union([typ, other])
221+
for expressions in capture_types.values():
222+
for key, capture_list in expressions.items():
223+
if other_types := [entry[1] for entry in capture_list]:
224+
typ = make_simplified_union(other_types)
225+
else:
226+
typ = UninhabitedType()
196227

197-
captures[capture_list[0][0]] = typ
228+
captures[capture_list[0][0]] = typ
198229

199230
union_type = make_simplified_union(types)
200231
return PatternType(union_type, current_type, captures)
@@ -284,12 +315,37 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
284315
contracted_inner_types = self.contract_starred_pattern_types(
285316
inner_types, star_position, required_patterns
286317
)
287-
for p, t in zip(o.patterns, contracted_inner_types):
288-
pattern_type = self.accept(p, t)
318+
current_subjects: list[list[Expression]] = [[] for _ in range(len(contracted_inner_types))]
319+
end_pos = len(contracted_inner_types) if star_position is None else star_position
320+
for subject in self.subject_context[-1]:
321+
if isinstance(subject, (ListExpr, TupleExpr)):
322+
# For list and tuple expressions, lookup expression in items
323+
for i in range(end_pos):
324+
if i < len(subject.items):
325+
current_subjects[i].append(subject.items[i])
326+
if star_position is not None:
327+
for i in range(star_position + 1, len(contracted_inner_types)):
328+
offset = len(contracted_inner_types) - i
329+
if offset <= len(subject.items):
330+
current_subjects[i].append(subject.items[-offset])
331+
else:
332+
# Support x[0], x[1], ... lookup until wildcard
333+
for i in range(end_pos):
334+
current_subjects[i].append(IndexExpr(subject, IntExpr(i)))
335+
# For everything after wildcard use x[-2], x[-1]
336+
for i in range((star_position or -1) + 1, len(contracted_inner_types)):
337+
offset = len(contracted_inner_types) - i
338+
current_subjects[i].append(IndexExpr(subject, UnaryExpr("-", IntExpr(offset))))
339+
for p, t, s in zip(o.patterns, contracted_inner_types, current_subjects):
340+
pattern_type = self.accept(p, t, s)
289341
typ, rest, type_map = pattern_type
290342
contracted_new_inner_types.append(typ)
291343
contracted_rest_inner_types.append(rest)
292344
self.update_type_map(captures, type_map)
345+
if s:
346+
self.update_type_map(
347+
captures, {subject: typ for subject in s}, fail_multiple_assignments=False
348+
)
293349

294350
new_inner_types = self.expand_starred_pattern_types(
295351
contracted_new_inner_types, star_position, len(inner_types), unpack_index is not None
@@ -473,11 +529,18 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
473529
if inner_type is None:
474530
can_match = False
475531
inner_type = self.chk.named_type("builtins.object")
476-
pattern_type = self.accept(value, inner_type)
532+
current_subjects: list[Expression] = [
533+
IndexExpr(s, key) for s in self.subject_context[-1]
534+
]
535+
pattern_type = self.accept(value, inner_type, current_subjects)
477536
if is_uninhabited(pattern_type.type):
478537
can_match = False
479538
else:
480539
self.update_type_map(captures, pattern_type.captures)
540+
if current_subjects:
541+
self.update_type_map(
542+
captures, {subject: pattern_type.type for subject in current_subjects}
543+
)
481544

482545
if o.rest is not None:
483546
mapping = self.chk.named_type("typing.Mapping")
@@ -581,7 +644,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
581644
if self.should_self_match(typ):
582645
if len(o.positionals) > 1:
583646
self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
584-
pattern_type = self.accept(o.positionals[0], narrowed_type)
647+
pattern_type = self.accept(o.positionals[0], narrowed_type, [])
585648
if not is_uninhabited(pattern_type.type):
586649
return PatternType(
587650
pattern_type.type,
@@ -681,11 +744,20 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
681744
elif keyword is not None:
682745
new_type = self.chk.add_any_attribute_to_type(new_type, keyword)
683746

684-
inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type)
747+
current_subjects: list[Expression] = []
748+
if keyword is not None:
749+
current_subjects = [MemberExpr(s, keyword) for s in self.subject_context[-1]]
750+
inner_type, inner_rest_type, inner_captures = self.accept(
751+
pattern, key_type, current_subjects
752+
)
685753
if is_uninhabited(inner_type):
686754
can_match = False
687755
else:
688756
self.update_type_map(captures, inner_captures)
757+
if current_subjects:
758+
self.update_type_map(
759+
captures, {subject: inner_type for subject in current_subjects}
760+
)
689761
if not is_uninhabited(inner_rest_type):
690762
rest_type = current_type
691763

@@ -732,17 +804,22 @@ def generate_types_from_names(self, type_names: list[str]) -> list[Type]:
732804
return types
733805

734806
def update_type_map(
735-
self, original_type_map: dict[Expression, Type], extra_type_map: dict[Expression, Type]
807+
self,
808+
original_type_map: dict[Expression, Type],
809+
extra_type_map: dict[Expression, Type],
810+
fail_multiple_assignments: bool = True,
736811
) -> None:
737812
# Calculating this would not be needed if TypeMap directly used literal hashes instead of
738813
# expressions, as suggested in the TODO above it's definition
739814
already_captured = {literal_hash(expr) for expr in original_type_map}
740815
for expr, typ in extra_type_map.items():
741816
if literal_hash(expr) in already_captured:
742-
node = get_var(expr)
743-
self.msg.fail(
744-
message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr
745-
)
817+
if (node := get_var(expr)) is None:
818+
continue
819+
if fail_multiple_assignments:
820+
self.msg.fail(
821+
message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr
822+
)
746823
else:
747824
original_type_map[expr] = typ
748825

@@ -794,12 +871,17 @@ def get_match_arg_names(typ: TupleType) -> list[str | None]:
794871
return args
795872

796873

797-
def get_var(expr: Expression) -> Var:
874+
def get_var(expr: Expression) -> Var | None:
798875
"""
799876
Warning: this in only true for expressions captured by a match statement.
800877
Don't call it from anywhere else
801878
"""
802-
assert isinstance(expr, NameExpr), expr
879+
if isinstance(expr, MemberExpr):
880+
return get_var(expr.expr)
881+
if isinstance(expr, IndexExpr):
882+
return get_var(expr.base)
883+
if not isinstance(expr, NameExpr):
884+
return None
803885
node = expr.node
804886
assert isinstance(node, Var), node
805887
return node

mypy/literals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def visit_set_expr(self, e: SetExpr) -> Key | None:
228228
return self.seq_expr(e, "Set")
229229

230230
def visit_index_expr(self, e: IndexExpr) -> Key | None:
231-
if literal(e.index) == LITERAL_YES:
231+
if literal(e.index) != LITERAL_NO:
232232
return ("Index", literal_hash(e.base), literal_hash(e.index))
233233
return None
234234

0 commit comments

Comments
 (0)