Skip to content

Commit d9735c9

Browse files
authored
Extend conversion (#686)
1 parent 429ad16 commit d9735c9

File tree

5 files changed

+23
-15
lines changed

5 files changed

+23
-15
lines changed

ext/IntervalArithmeticForwardDiffExt.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module IntervalArithmeticForwardDiffExt
22

33
using IntervalArithmetic, ForwardDiff
4-
using ForwardDiff: Dual, , value, partials
4+
using ForwardDiff: Dual, Partials, , value, partials
55

66
#
77

@@ -70,4 +70,8 @@ function Base.:(^)(x::ExactReal, y::Dual{<:Ty}) where {Ty}
7070
end
7171
end
7272

73+
# resolve ambiguity
74+
75+
Base.convert(::Type{Dual{T,V,N}}, x::ExactReal) where {T,V,N} = Dual{T}(V(x), zero(Partials{N,V}))
76+
7377
end

src/intervals/construction.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -367,10 +367,6 @@ isguaranteed(x::Complex{<:Interval}) = isguaranteed(real(x)) & isguaranteed(imag
367367

368368
isguaranteed(::Number) = false
369369

370-
Interval{T}(x::Interval) where {T<:NumTypes} = convert(Interval{T}, x) # needed to resolve method ambiguity
371-
# Interval{T}(x) where {T<:NumTypes} = convert(Interval{T}, x)
372-
# Interval{T}(x::Interval{T}) where {T<:NumTypes} = convert(Interval{T}, x) # needed to resolve method ambiguity
373-
374370
#
375371

376372
"""
@@ -563,6 +559,10 @@ Base.promote_rule(::Type{T}, ::Type{Interval{S}}) where {T<:AbstractIrrational,S
563559

564560
# conversion
565561

562+
Interval{T}(x::Real) where {T<:NumTypes} = convert(Interval{T}, x)
563+
Interval(x::Real) = Interval{promote_numtype(numtype(x), numtype(x))}(x)
564+
Interval{T}(x::Interval) where {T<:NumTypes} = convert(Interval{T}, x) # needed to resolve method ambiguity
565+
566566
Base.convert(::Type{Interval{T}}, x::Interval) where {T<:NumTypes} = interval(T, x)
567567

568568
function Base.convert(::Type{Interval{T}}, x::Complex{<:Interval}) where {T<:NumTypes}

src/intervals/exact_literals.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ Base.promote_rule(::Type{ExactReal{T}}, ::Type{ExactReal{S}}) where {T<:Real,S<:
8787

8888
# to BareInterval
8989

90+
BareInterval{T}(x::ExactReal) where {T<:NumTypes} = convert(BareInterval{T}, x)
91+
BareInterval(x::ExactReal) = BareInterval{promote_numtype(numtype(x.value), numtype(x.value))}(x)
92+
9093
Base.convert(::Type{BareInterval{T}}, x::ExactReal) where {T<:NumTypes} = bareinterval(T, x.value)
9194

9295
Base.promote_rule(::Type{BareInterval{T}}, ::Type{ExactReal{S}}) where {T<:NumTypes,S<:Real} =
@@ -105,8 +108,9 @@ Base.promote_rule(::Type{ExactReal{T}}, ::Type{Interval{S}}) where {T<:Real,S<:N
105108

106109
# to Real
107110

108-
# allows Interval{<:NumTypes}(::ExactReal)
109111
(::Type{T})(x::ExactReal) where {T<:Real} = convert(T, x)
112+
Interval{T}(x::ExactReal) where {T<:NumTypes} = convert(Interval{T}, x) # needed to resolve ambiguity
113+
Interval(x::ExactReal) = Interval{promote_numtype(numtype(x.value), numtype(x.value))}(x) # needed to resolve ambiguity
110114

111115
Base.convert(::Type{T}, x::ExactReal) where {T<:Real} = convert(T, x.value)
112116

test/interval_tests/construction.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ end
4343
@test_throws MethodError BareInterval(1, 2)
4444
@test_throws MethodError BareInterval{Float64}(1, 2)
4545

46-
@test_throws MethodError Interval(1)
47-
@test_throws MethodError Interval{Float64}(1)
46+
@test !isguaranteed(Interval(1))
47+
@test !isguaranteed(Interval{Float64}(1))
4848
@test_throws MethodError Interval(1, 2)
4949
@test_throws MethodError Interval{Float64}(1, 2)
5050

@@ -163,7 +163,7 @@ end
163163
i = interval(IS.Interval(0.1, 2))
164164
@test isequal_interval(i, interval(0.1, 2.)) && !isguaranteed(i)
165165
@test interval(Float64, IS.Interval(0.1, 2)) === i
166-
166+
167167
i = interval(IS.iv"[0.1, Inf)")
168168
@test isequal_interval(i, interval(0.1, Inf)) && !isguaranteed(i)
169169
@test interval(IS.iv"[0.1, Inf]") === nai(Float64)

test/interval_tests/forwarddiff.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ end
4646
(t) = ForwardDiff.derivative(ψ, t)
4747
ddψ(t) = ForwardDiff.derivative(dψ, t)
4848
dddψ(t) = ForwardDiff.derivative(ddψ, t)
49-
@test ψ′(0) === (0) && !isguaranteed(ψ′(0))
50-
@test_broken ψ′′(0) === ddψ(0) && !isguaranteed(ψ′′(0)) # rely on `Interval{T}(::Real)` being defined
51-
@test_broken ψ′′′(0) === dddψ(0) && !isguaranteed(ψ′′′(0)) # rely on `Interval{T}(::Real)` being defined
49+
@test ψ′(0) === (0) && !isguaranteed(ψ′(0))
50+
@test ψ′′(0) === ddψ(0) && !isguaranteed(ψ′′(0))
51+
@test ψ′′′(0) === dddψ(0) && !isguaranteed(ψ′′′(0))
5252
t₀ = interval(0)
5353
@test ψ′(t₀) === (t₀) && isguaranteed(ψ′(t₀))
5454
@test ψ′′(t₀) === ddψ(t₀) && isguaranteed(ψ′′(t₀))
@@ -73,14 +73,14 @@ end
7373
@test isguaranteed(dfdy)
7474
@test isguaranteed(grad[1])
7575
@test isguaranteed(grad[2])
76-
76+
7777
if iszero(x) && y < 0
7878
@test decoration(dfdx) == trv
7979
else
8080
@test in_interval(ForwardDiff.derivative(fx, x), dfdx)
8181
end
8282

83-
if iszero(x) && y <= 0
83+
if iszero(x) && y <= 0
8484
@test decoration(dfdy) == trv
8585
else
8686
@test in_interval(ForwardDiff.derivative(fy, y), dfdy)
@@ -104,4 +104,4 @@ end
104104
@exact g(x) = 2^x + 6sin(x^3) - 33
105105
@test isguaranteed(ForwardDiff.derivative(f, interval(1)))
106106
end
107-
end
107+
end

0 commit comments

Comments
 (0)