10
10
from mypy .checkmember import analyze_member_access
11
11
from mypy .expandtype import expand_type_by_instance
12
12
from mypy .join import join_types
13
- from mypy .literals import literal_hash
13
+ from mypy .literals import Key , literal_hash
14
14
from mypy .maptype import map_instance_to_supertype
15
15
from mypy .meet import narrow_declared_type
16
16
from mypy .messages import MessageBuilder
17
- from mypy .nodes import ARG_POS , Context , Expression , NameExpr , TypeAlias , TypeInfo , Var
17
+ from mypy .nodes import (
18
+ ARG_POS ,
19
+ Context ,
20
+ Expression ,
21
+ IndexExpr ,
22
+ IntExpr ,
23
+ ListExpr ,
24
+ MemberExpr ,
25
+ NameExpr ,
26
+ TupleExpr ,
27
+ TypeAlias ,
28
+ TypeInfo ,
29
+ UnaryExpr ,
30
+ Var ,
31
+ )
18
32
from mypy .options import Options
19
33
from mypy .patterns import (
20
34
AsPattern ,
@@ -96,10 +110,8 @@ class PatternChecker(PatternVisitor[PatternType]):
96
110
msg : MessageBuilder
97
111
# Currently unused
98
112
plugin : Plugin
99
- # The expression being matched against the pattern
100
- subject : Expression
101
-
102
- subject_type : Type
113
+ # The expressions being matched against the (sub)pattern
114
+ subject_context : list [list [Expression ]]
103
115
# Type of the subject to check the (sub)pattern against
104
116
type_context : list [Type ]
105
117
# Types that match against self instead of their __match_args__ if used as a class pattern
@@ -118,24 +130,28 @@ def __init__(
118
130
self .msg = msg
119
131
self .plugin = plugin
120
132
133
+ self .subject_context = []
121
134
self .type_context = []
122
135
self .self_match_types = self .generate_types_from_names (self_match_type_names )
123
136
self .non_sequence_match_types = self .generate_types_from_names (
124
137
non_sequence_match_type_names
125
138
)
126
139
self .options = options
127
140
128
- def accept (self , o : Pattern , type_context : Type ) -> PatternType :
141
+ def accept (self , o : Pattern , type_context : Type , subject : list [Expression ]) -> PatternType :
142
+ self .subject_context .append (subject )
129
143
self .type_context .append (type_context )
130
144
result = o .accept (self )
145
+ self .subject_context .pop ()
131
146
self .type_context .pop ()
132
147
133
148
return result
134
149
135
150
def visit_as_pattern (self , o : AsPattern ) -> PatternType :
151
+ current_subject = self .subject_context [- 1 ]
136
152
current_type = self .type_context [- 1 ]
137
153
if o .pattern is not None :
138
- pattern_type = self .accept (o .pattern , current_type )
154
+ pattern_type = self .accept (o .pattern , current_type , current_subject )
139
155
typ , rest_type , type_map = pattern_type
140
156
else :
141
157
typ , rest_type , type_map = current_type , UninhabitedType (), {}
@@ -150,14 +166,15 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType:
150
166
return PatternType (typ , rest_type , type_map )
151
167
152
168
def visit_or_pattern (self , o : OrPattern ) -> PatternType :
169
+ current_subject = self .subject_context [- 1 ]
153
170
current_type = self .type_context [- 1 ]
154
171
155
172
#
156
173
# Check all the subpatterns
157
174
#
158
- pattern_types = []
175
+ pattern_types : list [ PatternType ] = []
159
176
for pattern in o .patterns :
160
- pattern_type = self .accept (pattern , current_type )
177
+ pattern_type = self .accept (pattern , current_type , current_subject )
161
178
pattern_types .append (pattern_type )
162
179
if not is_uninhabited (pattern_type .type ):
163
180
current_type = pattern_type .rest_type
@@ -173,28 +190,42 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
173
190
#
174
191
# Check the capture types
175
192
#
176
- capture_types : dict [Var , list [tuple [Expression , Type ]]] = defaultdict (list )
193
+ capture_types : dict [Var , dict [Key | None , list [tuple [Expression , Type ]]]] = defaultdict (
194
+ lambda : defaultdict (list )
195
+ )
196
+ capture_expr_keys : set [Key | None ] = set ()
177
197
# Collect captures from the first subpattern
178
198
for expr , typ in pattern_types [0 ].captures .items ():
179
- node = get_var (expr )
180
- capture_types [node ].append ((expr , typ ))
199
+ if (node := get_var (expr )) is None :
200
+ continue
201
+ key = literal_hash (expr )
202
+ capture_types [node ][key ].append ((expr , typ ))
203
+ if isinstance (expr , NameExpr ):
204
+ capture_expr_keys .add (key )
181
205
182
206
# Check if other subpatterns capture the same names
183
207
for i , pattern_type in enumerate (pattern_types [1 :]):
184
- vars = {get_var (expr ) for expr , _ in pattern_type .captures .items ()}
185
- if capture_types .keys () != vars :
208
+ vars = {
209
+ literal_hash (expr ) for expr in pattern_type .captures if isinstance (expr , NameExpr )
210
+ }
211
+ if capture_expr_keys != vars :
212
+ # Only fail for directly captured names (with NameExpr)
186
213
self .msg .fail (message_registry .OR_PATTERN_ALTERNATIVE_NAMES , o .patterns [i ])
187
214
for expr , typ in pattern_type .captures .items ():
188
- node = get_var (expr )
189
- capture_types [node ].append ((expr , typ ))
215
+ if (node := get_var (expr )) is None :
216
+ continue
217
+ key = literal_hash (expr )
218
+ capture_types [node ][key ].append ((expr , typ ))
190
219
191
220
captures : dict [Expression , Type ] = {}
192
- for capture_list in capture_types .values ():
193
- typ = UninhabitedType ()
194
- for _ , other in capture_list :
195
- typ = make_simplified_union ([typ , other ])
221
+ for expressions in capture_types .values ():
222
+ for key , capture_list in expressions .items ():
223
+ if other_types := [entry [1 ] for entry in capture_list ]:
224
+ typ = make_simplified_union (other_types )
225
+ else :
226
+ typ = UninhabitedType ()
196
227
197
- captures [capture_list [0 ][0 ]] = typ
228
+ captures [capture_list [0 ][0 ]] = typ
198
229
199
230
union_type = make_simplified_union (types )
200
231
return PatternType (union_type , current_type , captures )
@@ -284,12 +315,37 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
284
315
contracted_inner_types = self .contract_starred_pattern_types (
285
316
inner_types , star_position , required_patterns
286
317
)
287
- for p , t in zip (o .patterns , contracted_inner_types ):
288
- pattern_type = self .accept (p , t )
318
+ current_subjects : list [list [Expression ]] = [[] for _ in range (len (contracted_inner_types ))]
319
+ end_pos = len (contracted_inner_types ) if star_position is None else star_position
320
+ for subject in self .subject_context [- 1 ]:
321
+ if isinstance (subject , (ListExpr , TupleExpr )):
322
+ # For list and tuple expressions, lookup expression in items
323
+ for i in range (end_pos ):
324
+ if i < len (subject .items ):
325
+ current_subjects [i ].append (subject .items [i ])
326
+ if star_position is not None :
327
+ for i in range (star_position + 1 , len (contracted_inner_types )):
328
+ offset = len (contracted_inner_types ) - i
329
+ if offset <= len (subject .items ):
330
+ current_subjects [i ].append (subject .items [- offset ])
331
+ else :
332
+ # Support x[0], x[1], ... lookup until wildcard
333
+ for i in range (end_pos ):
334
+ current_subjects [i ].append (IndexExpr (subject , IntExpr (i )))
335
+ # For everything after wildcard use x[-2], x[-1]
336
+ for i in range ((star_position or - 1 ) + 1 , len (contracted_inner_types )):
337
+ offset = len (contracted_inner_types ) - i
338
+ current_subjects [i ].append (IndexExpr (subject , UnaryExpr ("-" , IntExpr (offset ))))
339
+ for p , t , s in zip (o .patterns , contracted_inner_types , current_subjects ):
340
+ pattern_type = self .accept (p , t , s )
289
341
typ , rest , type_map = pattern_type
290
342
contracted_new_inner_types .append (typ )
291
343
contracted_rest_inner_types .append (rest )
292
344
self .update_type_map (captures , type_map )
345
+ if s :
346
+ self .update_type_map (
347
+ captures , {subject : typ for subject in s }, fail_multiple_assignments = False
348
+ )
293
349
294
350
new_inner_types = self .expand_starred_pattern_types (
295
351
contracted_new_inner_types , star_position , len (inner_types ), unpack_index is not None
@@ -473,11 +529,18 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
473
529
if inner_type is None :
474
530
can_match = False
475
531
inner_type = self .chk .named_type ("builtins.object" )
476
- pattern_type = self .accept (value , inner_type )
532
+ current_subjects : list [Expression ] = [
533
+ IndexExpr (s , key ) for s in self .subject_context [- 1 ]
534
+ ]
535
+ pattern_type = self .accept (value , inner_type , current_subjects )
477
536
if is_uninhabited (pattern_type .type ):
478
537
can_match = False
479
538
else :
480
539
self .update_type_map (captures , pattern_type .captures )
540
+ if current_subjects :
541
+ self .update_type_map (
542
+ captures , {subject : pattern_type .type for subject in current_subjects }
543
+ )
481
544
482
545
if o .rest is not None :
483
546
mapping = self .chk .named_type ("typing.Mapping" )
@@ -581,7 +644,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
581
644
if self .should_self_match (typ ):
582
645
if len (o .positionals ) > 1 :
583
646
self .msg .fail (message_registry .CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS , o )
584
- pattern_type = self .accept (o .positionals [0 ], narrowed_type )
647
+ pattern_type = self .accept (o .positionals [0 ], narrowed_type , [] )
585
648
if not is_uninhabited (pattern_type .type ):
586
649
return PatternType (
587
650
pattern_type .type ,
@@ -681,11 +744,20 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
681
744
elif keyword is not None :
682
745
new_type = self .chk .add_any_attribute_to_type (new_type , keyword )
683
746
684
- inner_type , inner_rest_type , inner_captures = self .accept (pattern , key_type )
747
+ current_subjects : list [Expression ] = []
748
+ if keyword is not None :
749
+ current_subjects = [MemberExpr (s , keyword ) for s in self .subject_context [- 1 ]]
750
+ inner_type , inner_rest_type , inner_captures = self .accept (
751
+ pattern , key_type , current_subjects
752
+ )
685
753
if is_uninhabited (inner_type ):
686
754
can_match = False
687
755
else :
688
756
self .update_type_map (captures , inner_captures )
757
+ if current_subjects :
758
+ self .update_type_map (
759
+ captures , {subject : inner_type for subject in current_subjects }
760
+ )
689
761
if not is_uninhabited (inner_rest_type ):
690
762
rest_type = current_type
691
763
@@ -732,17 +804,22 @@ def generate_types_from_names(self, type_names: list[str]) -> list[Type]:
732
804
return types
733
805
734
806
def update_type_map (
735
- self , original_type_map : dict [Expression , Type ], extra_type_map : dict [Expression , Type ]
807
+ self ,
808
+ original_type_map : dict [Expression , Type ],
809
+ extra_type_map : dict [Expression , Type ],
810
+ fail_multiple_assignments : bool = True ,
736
811
) -> None :
737
812
# Calculating this would not be needed if TypeMap directly used literal hashes instead of
738
813
# expressions, as suggested in the TODO above it's definition
739
814
already_captured = {literal_hash (expr ) for expr in original_type_map }
740
815
for expr , typ in extra_type_map .items ():
741
816
if literal_hash (expr ) in already_captured :
742
- node = get_var (expr )
743
- self .msg .fail (
744
- message_registry .MULTIPLE_ASSIGNMENTS_IN_PATTERN .format (node .name ), expr
745
- )
817
+ if (node := get_var (expr )) is None :
818
+ continue
819
+ if fail_multiple_assignments :
820
+ self .msg .fail (
821
+ message_registry .MULTIPLE_ASSIGNMENTS_IN_PATTERN .format (node .name ), expr
822
+ )
746
823
else :
747
824
original_type_map [expr ] = typ
748
825
@@ -794,12 +871,17 @@ def get_match_arg_names(typ: TupleType) -> list[str | None]:
794
871
return args
795
872
796
873
797
- def get_var (expr : Expression ) -> Var :
874
+ def get_var (expr : Expression ) -> Var | None :
798
875
"""
799
876
Warning: this in only true for expressions captured by a match statement.
800
877
Don't call it from anywhere else
801
878
"""
802
- assert isinstance (expr , NameExpr ), expr
879
+ if isinstance (expr , MemberExpr ):
880
+ return get_var (expr .expr )
881
+ if isinstance (expr , IndexExpr ):
882
+ return get_var (expr .base )
883
+ if not isinstance (expr , NameExpr ):
884
+ return None
803
885
node = expr .node
804
886
assert isinstance (node , Var ), node
805
887
return node
0 commit comments