Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5567,7 +5567,10 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
# capture variable may depend on multiple patterns (it
# will be a union of all capture types). This pass ignores
# guard expressions.
pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns]
pattern_types = [
self.pattern_checker.accept(p, subject_type, [unwrapped_subject])
for p in s.patterns
]
type_maps: list[TypeMap] = [t.captures for t in pattern_types]
inferred_types = self.infer_variable_types_from_type_maps(type_maps)

Expand All @@ -5577,7 +5580,9 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
current_subject_type = self.expr_checker.narrow_type_from_binder(
named_subject, subject_type
)
pattern_type = self.pattern_checker.accept(p, current_subject_type)
pattern_type = self.pattern_checker.accept(
p, current_subject_type, [unwrapped_subject]
)
with self.binder.frame_context(can_skip=True, fall_through=2):
if b.is_unreachable or isinstance(
get_proper_type(pattern_type.type), UninhabitedType
Expand Down
150 changes: 116 additions & 34 deletions mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,25 @@
from mypy.checkmember import analyze_member_access
from mypy.expandtype import expand_type_by_instance
from mypy.join import join_types
from mypy.literals import literal_hash
from mypy.literals import Key, literal_hash
from mypy.maptype import map_instance_to_supertype
from mypy.meet import narrow_declared_type
from mypy.messages import MessageBuilder
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, TypeInfo, Var
from mypy.nodes import (
ARG_POS,
Context,
Expression,
IndexExpr,
IntExpr,
ListExpr,
MemberExpr,
NameExpr,
TupleExpr,
TypeAlias,
TypeInfo,
UnaryExpr,
Var,
)
from mypy.options import Options
from mypy.patterns import (
AsPattern,
Expand Down Expand Up @@ -96,10 +110,8 @@ class PatternChecker(PatternVisitor[PatternType]):
msg: MessageBuilder
# Currently unused
plugin: Plugin
# The expression being matched against the pattern
subject: Expression

subject_type: Type
# The expressions being matched against the (sub)pattern
subject_context: list[list[Expression]]
# Type of the subject to check the (sub)pattern against
type_context: list[Type]
# Types that match against self instead of their __match_args__ if used as a class pattern
Expand All @@ -118,24 +130,28 @@ def __init__(
self.msg = msg
self.plugin = plugin

self.subject_context = []
self.type_context = []
self.self_match_types = self.generate_types_from_names(self_match_type_names)
self.non_sequence_match_types = self.generate_types_from_names(
non_sequence_match_type_names
)
self.options = options

def accept(self, o: Pattern, type_context: Type) -> PatternType:
def accept(self, o: Pattern, type_context: Type, subject: list[Expression]) -> PatternType:
self.subject_context.append(subject)
self.type_context.append(type_context)
result = o.accept(self)
self.subject_context.pop()
self.type_context.pop()

return result

def visit_as_pattern(self, o: AsPattern) -> PatternType:
current_subject = self.subject_context[-1]
current_type = self.type_context[-1]
if o.pattern is not None:
pattern_type = self.accept(o.pattern, current_type)
pattern_type = self.accept(o.pattern, current_type, current_subject)
typ, rest_type, type_map = pattern_type
else:
typ, rest_type, type_map = current_type, UninhabitedType(), {}
Expand All @@ -150,14 +166,15 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType:
return PatternType(typ, rest_type, type_map)

def visit_or_pattern(self, o: OrPattern) -> PatternType:
current_subject = self.subject_context[-1]
current_type = self.type_context[-1]

#
# Check all the subpatterns
#
pattern_types = []
pattern_types: list[PatternType] = []
for pattern in o.patterns:
pattern_type = self.accept(pattern, current_type)
pattern_type = self.accept(pattern, current_type, current_subject)
pattern_types.append(pattern_type)
if not is_uninhabited(pattern_type.type):
current_type = pattern_type.rest_type
Expand All @@ -173,28 +190,42 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
#
# Check the capture types
#
capture_types: dict[Var, list[tuple[Expression, Type]]] = defaultdict(list)
capture_types: dict[Var, dict[Key | None, list[tuple[Expression, Type]]]] = defaultdict(
lambda: defaultdict(list)
)
capture_expr_keys: set[Key | None] = set()
# Collect captures from the first subpattern
for expr, typ in pattern_types[0].captures.items():
node = get_var(expr)
capture_types[node].append((expr, typ))
if (node := get_var(expr)) is None:
continue
key = literal_hash(expr)
capture_types[node][key].append((expr, typ))
if isinstance(expr, NameExpr):
capture_expr_keys.add(key)

# Check if other subpatterns capture the same names
for i, pattern_type in enumerate(pattern_types[1:]):
vars = {get_var(expr) for expr, _ in pattern_type.captures.items()}
if capture_types.keys() != vars:
vars = {
literal_hash(expr) for expr in pattern_type.captures if isinstance(expr, NameExpr)
}
if capture_expr_keys != vars:
# Only fail for directly captured names (with NameExpr)
self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, o.patterns[i])
for expr, typ in pattern_type.captures.items():
node = get_var(expr)
capture_types[node].append((expr, typ))
if (node := get_var(expr)) is None:
continue
key = literal_hash(expr)
capture_types[node][key].append((expr, typ))

captures: dict[Expression, Type] = {}
for capture_list in capture_types.values():
typ = UninhabitedType()
for _, other in capture_list:
typ = make_simplified_union([typ, other])
for expressions in capture_types.values():
for key, capture_list in expressions.items():
if other_types := [entry[1] for entry in capture_list]:
typ = make_simplified_union(other_types)
else:
typ = UninhabitedType()

captures[capture_list[0][0]] = typ
captures[capture_list[0][0]] = typ

union_type = make_simplified_union(types)
return PatternType(union_type, current_type, captures)
Expand Down Expand Up @@ -284,12 +315,37 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
contracted_inner_types = self.contract_starred_pattern_types(
inner_types, star_position, required_patterns
)
for p, t in zip(o.patterns, contracted_inner_types):
pattern_type = self.accept(p, t)
current_subjects: list[list[Expression]] = [[] for _ in range(len(contracted_inner_types))]
end_pos = len(contracted_inner_types) if star_position is None else star_position
for subject in self.subject_context[-1]:
if isinstance(subject, (ListExpr, TupleExpr)):
# For list and tuple expressions, lookup expression in items
for i in range(end_pos):
if i < len(subject.items):
current_subjects[i].append(subject.items[i])
if star_position is not None:
for i in range(star_position + 1, len(contracted_inner_types)):
offset = len(contracted_inner_types) - i
if offset <= len(subject.items):
current_subjects[i].append(subject.items[-offset])
else:
# Support x[0], x[1], ... lookup until wildcard
for i in range(end_pos):
current_subjects[i].append(IndexExpr(subject, IntExpr(i)))
# For everything after wildcard use x[-2], x[-1]
for i in range((star_position or -1) + 1, len(contracted_inner_types)):
offset = len(contracted_inner_types) - i
current_subjects[i].append(IndexExpr(subject, UnaryExpr("-", IntExpr(offset))))
for p, t, s in zip(o.patterns, contracted_inner_types, current_subjects):
pattern_type = self.accept(p, t, s)
typ, rest, type_map = pattern_type
contracted_new_inner_types.append(typ)
contracted_rest_inner_types.append(rest)
self.update_type_map(captures, type_map)
if s:
self.update_type_map(
captures, {subject: typ for subject in s}, fail_multiple_assignments=False
)

new_inner_types = self.expand_starred_pattern_types(
contracted_new_inner_types, star_position, len(inner_types), unpack_index is not None
Expand Down Expand Up @@ -473,11 +529,18 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
if inner_type is None:
can_match = False
inner_type = self.chk.named_type("builtins.object")
pattern_type = self.accept(value, inner_type)
current_subjects: list[Expression] = [
IndexExpr(s, key) for s in self.subject_context[-1]
]
pattern_type = self.accept(value, inner_type, current_subjects)
if is_uninhabited(pattern_type.type):
can_match = False
else:
self.update_type_map(captures, pattern_type.captures)
if current_subjects:
self.update_type_map(
captures, {subject: pattern_type.type for subject in current_subjects}
)

if o.rest is not None:
mapping = self.chk.named_type("typing.Mapping")
Expand Down Expand Up @@ -581,7 +644,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
if self.should_self_match(typ):
if len(o.positionals) > 1:
self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
pattern_type = self.accept(o.positionals[0], narrowed_type)
pattern_type = self.accept(o.positionals[0], narrowed_type, [])
if not is_uninhabited(pattern_type.type):
return PatternType(
pattern_type.type,
Expand Down Expand Up @@ -681,11 +744,20 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
elif keyword is not None:
new_type = self.chk.add_any_attribute_to_type(new_type, keyword)

inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type)
current_subjects: list[Expression] = []
if keyword is not None:
current_subjects = [MemberExpr(s, keyword) for s in self.subject_context[-1]]
inner_type, inner_rest_type, inner_captures = self.accept(
pattern, key_type, current_subjects
)
if is_uninhabited(inner_type):
can_match = False
else:
self.update_type_map(captures, inner_captures)
if current_subjects:
self.update_type_map(
captures, {subject: inner_type for subject in current_subjects}
)
if not is_uninhabited(inner_rest_type):
rest_type = current_type

Expand Down Expand Up @@ -732,17 +804,22 @@ def generate_types_from_names(self, type_names: list[str]) -> list[Type]:
return types

def update_type_map(
self, original_type_map: dict[Expression, Type], extra_type_map: dict[Expression, Type]
self,
original_type_map: dict[Expression, Type],
extra_type_map: dict[Expression, Type],
fail_multiple_assignments: bool = True,
) -> None:
# Calculating this would not be needed if TypeMap directly used literal hashes instead of
# expressions, as suggested in the TODO above it's definition
already_captured = {literal_hash(expr) for expr in original_type_map}
for expr, typ in extra_type_map.items():
if literal_hash(expr) in already_captured:
node = get_var(expr)
self.msg.fail(
message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr
)
if (node := get_var(expr)) is None:
continue
if fail_multiple_assignments:
self.msg.fail(
message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr
)
else:
original_type_map[expr] = typ

Expand Down Expand Up @@ -794,12 +871,17 @@ def get_match_arg_names(typ: TupleType) -> list[str | None]:
return args


def get_var(expr: Expression) -> Var:
def get_var(expr: Expression) -> Var | None:
"""
Warning: this in only true for expressions captured by a match statement.
Don't call it from anywhere else
"""
assert isinstance(expr, NameExpr), expr
if isinstance(expr, MemberExpr):
return get_var(expr.expr)
if isinstance(expr, IndexExpr):
return get_var(expr.base)
if not isinstance(expr, NameExpr):
return None
node = expr.node
assert isinstance(node, Var), node
return node
Expand Down
2 changes: 1 addition & 1 deletion mypy/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def visit_set_expr(self, e: SetExpr) -> Key | None:
return self.seq_expr(e, "Set")

def visit_index_expr(self, e: IndexExpr) -> Key | None:
if literal(e.index) == LITERAL_YES:
if literal(e.index) != LITERAL_NO:
return ("Index", literal_hash(e.base), literal_hash(e.index))
return None

Expand Down
Loading
Loading