10
10
from mypy .checkmember import analyze_member_access
11
11
from mypy .expandtype import expand_type_by_instance
12
12
from mypy .join import join_types
13
- from mypy .literals import literal_hash
13
+ from mypy .literals import Key , literal_hash
14
14
from mypy .maptype import map_instance_to_supertype
15
15
from mypy .meet import narrow_declared_type
16
16
from mypy .messages import MessageBuilder
17
- from mypy .nodes import ARG_POS , Context , Expression , NameExpr , TypeAlias , TypeInfo , Var
17
+ from mypy .nodes import (
18
+ ARG_POS ,
19
+ Context ,
20
+ Expression ,
21
+ IndexExpr ,
22
+ IntExpr ,
23
+ MemberExpr ,
24
+ NameExpr ,
25
+ TypeAlias ,
26
+ TypeInfo ,
27
+ UnaryExpr ,
28
+ Var ,
29
+ )
18
30
from mypy .options import Options
19
31
from mypy .patterns import (
20
32
AsPattern ,
@@ -96,10 +108,8 @@ class PatternChecker(PatternVisitor[PatternType]):
96
108
msg : MessageBuilder
97
109
# Currently unused
98
110
plugin : Plugin
99
- # The expression being matched against the pattern
100
- subject : Expression
101
-
102
- subject_type : Type
111
+ # The expressions being matched against the (sub)pattern
112
+ subject_context : list [list [Expression ]]
103
113
# Type of the subject to check the (sub)pattern against
104
114
type_context : list [Type ]
105
115
# Types that match against self instead of their __match_args__ if used as a class pattern
@@ -118,24 +128,28 @@ def __init__(
118
128
self .msg = msg
119
129
self .plugin = plugin
120
130
131
+ self .subject_context = []
121
132
self .type_context = []
122
133
self .self_match_types = self .generate_types_from_names (self_match_type_names )
123
134
self .non_sequence_match_types = self .generate_types_from_names (
124
135
non_sequence_match_type_names
125
136
)
126
137
self .options = options
127
138
128
- def accept (self , o : Pattern , type_context : Type ) -> PatternType :
139
+ def accept (self , o : Pattern , type_context : Type , subject : list [Expression ]) -> PatternType :
140
+ self .subject_context .append (subject )
129
141
self .type_context .append (type_context )
130
142
result = o .accept (self )
143
+ self .subject_context .pop ()
131
144
self .type_context .pop ()
132
145
133
146
return result
134
147
135
148
def visit_as_pattern (self , o : AsPattern ) -> PatternType :
149
+ current_subject = self .subject_context [- 1 ]
136
150
current_type = self .type_context [- 1 ]
137
151
if o .pattern is not None :
138
- pattern_type = self .accept (o .pattern , current_type )
152
+ pattern_type = self .accept (o .pattern , current_type , current_subject )
139
153
typ , rest_type , type_map = pattern_type
140
154
else :
141
155
typ , rest_type , type_map = current_type , UninhabitedType (), {}
@@ -150,14 +164,15 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType:
150
164
return PatternType (typ , rest_type , type_map )
151
165
152
166
def visit_or_pattern (self , o : OrPattern ) -> PatternType :
167
+ current_subject = self .subject_context [- 1 ]
153
168
current_type = self .type_context [- 1 ]
154
169
155
170
#
156
171
# Check all the subpatterns
157
172
#
158
- pattern_types = []
173
+ pattern_types : list [ PatternType ] = []
159
174
for pattern in o .patterns :
160
- pattern_type = self .accept (pattern , current_type )
175
+ pattern_type = self .accept (pattern , current_type , current_subject )
161
176
pattern_types .append (pattern_type )
162
177
if not is_uninhabited (pattern_type .type ):
163
178
current_type = pattern_type .rest_type
@@ -173,28 +188,40 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
173
188
#
174
189
# Check the capture types
175
190
#
176
- capture_types : dict [Var , list [tuple [Expression , Type ]]] = defaultdict (list )
191
+ capture_types : dict [Var , dict [Key | None , list [tuple [Expression , Type ]]]] = defaultdict (
192
+ lambda : defaultdict (list )
193
+ )
194
+ capture_expr_keys : set [Key | None ] = set ()
177
195
# Collect captures from the first subpattern
178
196
for expr , typ in pattern_types [0 ].captures .items ():
179
197
node = get_var (expr )
180
- capture_types [node ].append ((expr , typ ))
198
+ key = literal_hash (expr )
199
+ capture_types [node ][key ].append ((expr , typ ))
200
+ if isinstance (expr , NameExpr ):
201
+ capture_expr_keys .add (key )
181
202
182
203
# Check if other subpatterns capture the same names
183
204
for i , pattern_type in enumerate (pattern_types [1 :]):
184
- vars = {get_var (expr ) for expr , _ in pattern_type .captures .items ()}
185
- if capture_types .keys () != vars :
205
+ vars = {
206
+ literal_hash (expr ) for expr in pattern_type .captures if isinstance (expr , NameExpr )
207
+ }
208
+ if capture_expr_keys != vars :
209
+ # Only fail for directly captured names (with NameExpr)
186
210
self .msg .fail (message_registry .OR_PATTERN_ALTERNATIVE_NAMES , o .patterns [i ])
187
211
for expr , typ in pattern_type .captures .items ():
188
212
node = get_var (expr )
189
- capture_types [node ].append ((expr , typ ))
213
+ key = literal_hash (expr )
214
+ capture_types [node ][key ].append ((expr , typ ))
190
215
191
216
captures : dict [Expression , Type ] = {}
192
- for capture_list in capture_types .values ():
193
- typ = UninhabitedType ()
194
- for _ , other in capture_list :
195
- typ = make_simplified_union ([typ , other ])
217
+ for expressions in capture_types .values ():
218
+ for key , capture_list in expressions .items ():
219
+ if other_types := [entry [1 ] for entry in capture_list ]:
220
+ typ = make_simplified_union (other_types )
221
+ else :
222
+ typ = UninhabitedType ()
196
223
197
- captures [capture_list [0 ][0 ]] = typ
224
+ captures [capture_list [0 ][0 ]] = typ
198
225
199
226
union_type = make_simplified_union (types )
200
227
return PatternType (union_type , current_type , captures )
@@ -284,12 +311,24 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
284
311
contracted_inner_types = self .contract_starred_pattern_types (
285
312
inner_types , star_position , required_patterns
286
313
)
287
- for p , t in zip (o .patterns , contracted_inner_types ):
288
- pattern_type = self .accept (p , t )
314
+ current_subjects : list [list [Expression ]] = [[] for _ in range (len (contracted_inner_types ))]
315
+ for s in self .subject_context [- 1 ]:
316
+ # Support x[0], x[1], ... lookup until wildcard
317
+ end_pos = len (contracted_inner_types ) if star_position is None else star_position
318
+ for i in range (end_pos ):
319
+ current_subjects [i ].append (IndexExpr (s , IntExpr (i )))
320
+ # For everything after wildcard use x[-2], x[-1]
321
+ for i in range ((star_position or - 1 ) + 1 , len (contracted_inner_types )):
322
+ offset = len (contracted_inner_types ) - i
323
+ current_subjects [i ].append (IndexExpr (s , UnaryExpr ("-" , IntExpr (offset ))))
324
+ for p , t , s in zip (o .patterns , contracted_inner_types , current_subjects ):
325
+ pattern_type = self .accept (p , t , s )
289
326
typ , rest , type_map = pattern_type
290
327
contracted_new_inner_types .append (typ )
291
328
contracted_rest_inner_types .append (rest )
292
329
self .update_type_map (captures , type_map )
330
+ if s :
331
+ self .update_type_map (captures , {subject : typ for subject in s })
293
332
294
333
new_inner_types = self .expand_starred_pattern_types (
295
334
contracted_new_inner_types , star_position , len (inner_types ), unpack_index is not None
@@ -473,11 +512,18 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
473
512
if inner_type is None :
474
513
can_match = False
475
514
inner_type = self .chk .named_type ("builtins.object" )
476
- pattern_type = self .accept (value , inner_type )
515
+ current_subjects : list [Expression ] = [
516
+ IndexExpr (s , key ) for s in self .subject_context [- 1 ]
517
+ ]
518
+ pattern_type = self .accept (value , inner_type , current_subjects )
477
519
if is_uninhabited (pattern_type .type ):
478
520
can_match = False
479
521
else :
480
522
self .update_type_map (captures , pattern_type .captures )
523
+ if current_subjects :
524
+ self .update_type_map (
525
+ captures , {subject : pattern_type .type for subject in current_subjects }
526
+ )
481
527
482
528
if o .rest is not None :
483
529
mapping = self .chk .named_type ("typing.Mapping" )
@@ -581,7 +627,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
581
627
if self .should_self_match (typ ):
582
628
if len (o .positionals ) > 1 :
583
629
self .msg .fail (message_registry .CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS , o )
584
- pattern_type = self .accept (o .positionals [0 ], narrowed_type )
630
+ pattern_type = self .accept (o .positionals [0 ], narrowed_type , [] )
585
631
if not is_uninhabited (pattern_type .type ):
586
632
return PatternType (
587
633
pattern_type .type ,
@@ -681,11 +727,20 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
681
727
elif keyword is not None :
682
728
new_type = self .chk .add_any_attribute_to_type (new_type , keyword )
683
729
684
- inner_type , inner_rest_type , inner_captures = self .accept (pattern , key_type )
730
+ current_subjects : list [Expression ] = []
731
+ if keyword is not None :
732
+ current_subjects = [MemberExpr (s , keyword ) for s in self .subject_context [- 1 ]]
733
+ inner_type , inner_rest_type , inner_captures = self .accept (
734
+ pattern , key_type , current_subjects
735
+ )
685
736
if is_uninhabited (inner_type ):
686
737
can_match = False
687
738
else :
688
739
self .update_type_map (captures , inner_captures )
740
+ if current_subjects :
741
+ self .update_type_map (
742
+ captures , {subject : inner_type for subject in current_subjects }
743
+ )
689
744
if not is_uninhabited (inner_rest_type ):
690
745
rest_type = current_type
691
746
@@ -799,6 +854,10 @@ def get_var(expr: Expression) -> Var:
799
854
Warning: this in only true for expressions captured by a match statement.
800
855
Don't call it from anywhere else
801
856
"""
857
+ if isinstance (expr , MemberExpr ):
858
+ return get_var (expr .expr )
859
+ if isinstance (expr , IndexExpr ):
860
+ return get_var (expr .base )
802
861
assert isinstance (expr , NameExpr ), expr
803
862
node = expr .node
804
863
assert isinstance (node , Var ), node
0 commit comments