Skip to content

Commit 541e83b

Browse files
vtjnashpchintalapudioscardssmith
authored andcommitted
inference: apply tmerge limit elementwise to the Union (JuliaLang#50927)
This allows forming larger unions, as long as each element in the Union is both relatively distinct and relatively simple. For example: tmerge(Base.BitSigned, Nothing) == Union{Nothing, Int128, Int16, Int32, Int64, Int8} tmerge(Tuple{Base.BitSigned, Int}, Nothing) == Union{Nothing, Tuple{Any, Int64}} tmerge(AbstractVector{Int}, Vector) == AbstractVector Disables a test from dc8d885, which does not seem possible to handle currently. This makes somewhat drastic changes to make this algorithm more commutative and simpler, since we dropped the final widening to `Any`. Co-authored-by: pchintalapudi <34727397+pchintalapudi@users.noreply.github.com> Co-authored-by: Oscar Smith <oscardssmith@gmail.com>
1 parent 9dcccc4 commit 541e83b

File tree

5 files changed

+163
-63
lines changed

5 files changed

+163
-63
lines changed

base/compiler/abstractinterpretation.jl

+1-5
Original file line numberDiff line numberDiff line change
@@ -3122,11 +3122,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
31223122
ssavaluetypes[currpc] = Any
31233123
continue
31243124
end
3125-
if !isempty(frame.ssavalue_uses[currpc])
3126-
record_ssa_assign!(𝕃ᵢ, currpc, type, frame)
3127-
else
3128-
ssavaluetypes[currpc] = type
3129-
end
3125+
record_ssa_assign!(𝕃ᵢ, currpc, type, frame)
31303126
end # for currpc in bbstart:bbend
31313127

31323128
# Case 1: Fallthrough termination

base/compiler/inferencestate.jl

+2-5
Original file line numberDiff line numberDiff line change
@@ -576,11 +576,8 @@ _topmod(sv::InferenceState) = _topmod(frame_module(sv))
576576
function record_ssa_assign!(𝕃ᵢ::AbstractLattice, ssa_id::Int, @nospecialize(new), frame::InferenceState)
577577
ssavaluetypes = frame.ssavaluetypes
578578
old = ssavaluetypes[ssa_id]
579-
if old === NOT_FOUND || !(𝕃ᵢ, new, old)
580-
# typically, we expect that old ⊑ new (that output information only
581-
# gets less precise with worse input information), but to actually
582-
# guarantee convergence we need to use tmerge here to ensure that is true
583-
ssavaluetypes[ssa_id] = old === NOT_FOUND ? new : tmerge(𝕃ᵢ, old, new)
579+
if old === NOT_FOUND || !is_lattice_equal(𝕃ᵢ, new, old)
580+
ssavaluetypes[ssa_id] = new
584581
W = frame.ip
585582
for r in frame.ssavalue_uses[ssa_id]
586583
if was_reached(frame, r)

base/compiler/typelimits.jl

+87-25
Original file line numberDiff line numberDiff line change
@@ -292,10 +292,17 @@ union_count_abstract(x::Union) = union_count_abstract(x.a) + union_count_abstrac
292292
union_count_abstract(@nospecialize(x)) = !isdispatchelem(x)
293293

294294
function issimpleenoughtype(@nospecialize t)
295+
ut = unwrap_unionall(t)
296+
ut isa DataType && ut.name.wrapper == t && return true
295297
return unionlen(t) + union_count_abstract(t) <= MAX_TYPEUNION_LENGTH &&
296298
unioncomplexity(t) <= MAX_TYPEUNION_COMPLEXITY
297299
end
298300

301+
# We may want to apply a stricter limit than issimpleenoughtype to
302+
# tupleelements individually, to try to keep the whole tuple under the limit,
303+
# even after complicated recursion and other operations on it elsewhere
304+
const issimpleenoughtupleelem = issimpleenoughtype
305+
299306
# A simplified type_more_complex query over the extended lattice
300307
# (assumes typeb ⊑ typea)
301308
@nospecializeinfer function issimplertype(𝕃::AbstractLattice, @nospecialize(typea), @nospecialize(typeb))
@@ -679,6 +686,33 @@ end
679686
return tmerge_types_slow(typea, typeb)
680687
end
681688

689+
@nospecializeinfer @noinline function tname_intersect(aname::Core.TypeName, bname::Core.TypeName)
690+
aname === bname && return aname
691+
if !isabstracttype(aname.wrapper) && !isabstracttype(bname.wrapper)
692+
return nothing # fast path
693+
end
694+
Any.name === aname && return aname
695+
a = unwrap_unionall(aname.wrapper)
696+
heighta = 0
697+
while a !== Any
698+
heighta += 1
699+
a = a.super
700+
end
701+
b = unwrap_unionall(bname.wrapper)
702+
heightb = 0
703+
while b !== Any
704+
b.name === aname && return aname
705+
heightb += 1
706+
b = b.super
707+
end
708+
a = unwrap_unionall(aname.wrapper)
709+
while heighta > heightb
710+
a = a.super
711+
heighta -= 1
712+
end
713+
return a.name === bname ? bname : nothing
714+
end
715+
682716
@nospecializeinfer @noinline function tmerge_types_slow(@nospecialize(typea::Type), @nospecialize(typeb::Type))
683717
# collect the list of types from past tmerge calls returning Union
684718
# and then reduce over that list
@@ -702,74 +736,95 @@ end
702736
# see if any of the union elements have the same TypeName
703737
# in which case, simplify this tmerge by replacing it with
704738
# the widest possible version of itself (the wrapper)
739+
simplify = falses(length(types))
705740
for i in 1:length(types)
741+
typenames[i] === Any.name && continue
706742
ti = types[i]
707743
for j in (i + 1):length(types)
708-
if typenames[i] === typenames[j]
744+
typenames[j] === Any.name && continue
745+
ijname = tname_intersect(typenames[i], typenames[j])
746+
if !(ijname === nothing)
709747
tj = types[j]
710748
if ti <: tj
711749
types[i] = Union{}
712750
typenames[i] = Any.name
751+
simplify[i] = false
752+
simplify[j] = true
713753
break
714754
elseif tj <: ti
715755
types[j] = Union{}
716756
typenames[j] = Any.name
757+
simplify[j] = false
758+
simplify[i] = true
717759
else
718-
if typenames[i] === Tuple.name
760+
if ijname === Tuple.name
719761
# try to widen Tuple slower: make a single non-concrete Tuple containing both
720762
# converge the Tuple element-wise if they are the same length
721763
# see 4ee2b41552a6bc95465c12ca66146d69b354317b, be59686f7613a2ccfd63491c7b354d0b16a95c05,
722764
widen = tuplemerge(unwrap_unionall(ti)::DataType, unwrap_unionall(tj)::DataType)
723765
widen = rewrap_unionall(rewrap_unionall(widen, ti), tj)
766+
simplify[j] = false
724767
else
725-
wr = typenames[i].wrapper
768+
wr = ijname.wrapper
726769
uw = unwrap_unionall(wr)::DataType
727770
ui = unwrap_unionall(ti)::DataType
771+
while ui.name !== ijname
772+
ui = ui.super
773+
end
728774
uj = unwrap_unionall(tj)::DataType
729-
merged = wr
775+
while uj.name !== ijname
776+
uj = uj.super
777+
end
778+
p = Vector{Any}(undef, length(uw.parameters))
779+
usep = true
780+
widen = wr
730781
for k = 1:length(uw.parameters)
731782
ui_k = ui.parameters[k]
732783
if ui_k === uj.parameters[k] && !has_free_typevars(ui_k)
733-
merged = merged{ui_k}
784+
p[k] = ui_k
785+
usep = true
734786
else
735-
merged = merged{uw.parameters[k]}
787+
p[k] = uw.parameters[k]
736788
end
737789
end
738-
widen = rewrap_unionall(merged, wr)
790+
if usep
791+
widen = rewrap_unionall(wr{p...}, wr)
792+
end
793+
simplify[j] = !usep
739794
end
740795
types[i] = Union{}
741796
typenames[i] = Any.name
797+
simplify[i] = false
742798
types[j] = widen
743799
break
744800
end
745801
end
746802
end
747803
end
748-
u = Union{types...}
749-
# don't let type unions get too big, if the above didn't reduce it enough
750-
if issimpleenoughtype(u)
751-
return u
752-
end
753-
# don't let the slow widening of Tuple cause the whole type to grow too fast
804+
# don't let elements of the union get too big, if the above didn't reduce something enough
754805
# Specifically widen Tuple{..., Union{lots of stuff}...} to Tuple{..., Any, ...}
806+
# Don't let Val{<:Val{<:Val}} keep nesting abstract levels either
755807
for i in 1:length(types)
808+
simplify[i] || continue
809+
ti = types[i]
810+
issimpleenoughtype(ti) && continue
756811
if typenames[i] === Tuple.name
757-
ti = types[i]
758-
tip = (unwrap_unionall(types[i])::DataType).parameters
812+
# otherwise we need to do a simple version of tuplemerge for one element now
813+
tip = (unwrap_unionall(ti)::DataType).parameters
759814
lt = length(tip)
760815
p = Vector{Any}(undef, lt)
761816
for j = 1:lt
762817
ui = tip[j]
763-
p[j] = (unioncomplexity(ui)==0) ? ui : isvarargtype(ui) ? Vararg : Any
818+
p[j] = issimpleenoughtupleelem(unwrapva(ui)) ? ui : isvarargtype(ui) ? Vararg : Any
764819
end
765820
types[i] = rewrap_unionall(Tuple{p...}, ti)
821+
else
822+
# this element is not simple enough yet, make it so now
823+
types[i] = typenames[i].wrapper
766824
end
767825
end
768826
u = Union{types...}
769-
if issimpleenoughtype(u)
770-
return u
771-
end
772-
return Any
827+
return u
773828
end
774829

775830
# the inverse of switchtupleunion, with limits on max element union size
@@ -791,7 +846,7 @@ function tuplemerge(a::DataType, b::DataType)
791846
p = Vector{Any}(undef, lt + vt)
792847
for i = 1:lt
793848
ui = Union{ap[i], bp[i]}
794-
p[i] = issimpleenoughtype(ui) ? ui : Any
849+
p[i] = issimpleenoughtupleelem(ui) ? ui : Any
795850
end
796851
# merge the remaining tail into a single, simple Tuple{Vararg{T}} (#22120)
797852
if vt
@@ -809,8 +864,10 @@ function tuplemerge(a::DataType, b::DataType)
809864
# or (equivalently?) iteratively took super-types until reaching a common wrapper
810865
# e.g. consider the results of `tuplemerge(Tuple{Complex}, Tuple{Number, Int})` and of
811866
# `tuplemerge(Tuple{Int}, Tuple{String}, Tuple{Int, String})`
812-
if !(ti <: tail)
813-
if tail <: ti
867+
# c.f. tname_intersect in the algorithm above
868+
hasfree = has_free_typevars(ti)
869+
if hasfree || !(ti <: tail)
870+
if !hasfree && tail <: ti
814871
tail = ti # widen to ti
815872
else
816873
uw = unwrap_unionall(tail)
@@ -838,11 +895,16 @@ function tuplemerge(a::DataType, b::DataType)
838895
end
839896
end
840897
end
841-
tail === Any && return Tuple # short-circuit loop
898+
tail === Any && return Tuple # short-circuit loops
842899
end
843900
end
844901
@assert !(tail === Union{})
845-
p[lt + 1] = Vararg{tail}
902+
if !issimpleenoughtupleelem(tail) || tail === Any
903+
p[lt + 1] = Vararg
904+
lt == 0 && return Tuple
905+
else
906+
p[lt + 1] = Vararg{tail}
907+
end
846908
end
847909
return Tuple{p...}
848910
end

base/compiler/typeutils.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ end
291291
unioncomplexity(@nospecialize x) = _unioncomplexity(x)::Int
292292
function _unioncomplexity(@nospecialize x)
293293
if isa(x, DataType)
294-
x.name === Tuple.name || isvarargtype(x) || return 0
294+
x.name === Tuple.name || return 0
295295
c = 0
296296
for ti in x.parameters
297297
c = max(c, unioncomplexity(ti))
@@ -302,7 +302,7 @@ function _unioncomplexity(@nospecialize x)
302302
elseif isa(x, UnionAll)
303303
return max(unioncomplexity(x.body), unioncomplexity(x.var.ub))
304304
elseif isa(x, TypeofVararg)
305-
return isdefined(x, :T) ? unioncomplexity(x.T) : 0
305+
return isdefined(x, :T) ? unioncomplexity(x.T) + 1 : 1
306306
else
307307
return 0
308308
end

0 commit comments

Comments
 (0)