From 69cc368812ac8b74aba6ed00ae50c4ea9b9213ed Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 18 Aug 2025 12:34:44 -0400 Subject: [PATCH 01/17] excise stuff in solve and utils --- src/solve.jl | 1266 +++++++++++++++++++++++++------------------------- src/utils.jl | 81 ---- 2 files changed, 633 insertions(+), 714 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 5bfd45ac9..092b06f59 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -7,539 +7,539 @@ NO_TSPAN_PROBS = Union{AbstractLinearProblem, AbstractNonlinearProblem, AbstractIntegralProblem, AbstractSteadyStateProblem, AbstractJumpProblem} -has_kwargs(_prob::AbstractDEProblem) = has_kwargs(typeof(_prob)) -Base.@pure __has_kwargs(::Type{T}) where {T} = :kwargs ∈ fieldnames(T) -has_kwargs(::Type{T}) where {T} = __has_kwargs(T) - -const allowedkeywords = (:dense, - :saveat, - :save_idxs, - :tstops, - :tspan, - :d_discontinuities, - :save_everystep, - :save_on, - :save_start, - :save_end, - :initialize_save, - :adaptive, - :abstol, - :reltol, - :dt, - :dtmax, - :dtmin, - :force_dtmin, - :internalnorm, - :controller, - :gamma, - :beta1, - :beta2, - :qmax, - :qmin, - :qsteady_min, - :qsteady_max, - :qoldinit, - :failfactor, - :calck, - :alias_u0, - :maxiters, - :maxtime, - :callback, - :isoutofdomain, - :unstable_check, - :verbose, - :merge_callbacks, - :progress, - :progress_steps, - :progress_name, - :progress_message, - :progress_id, - :timeseries_errors, - :dense_errors, - :weak_timeseries_errors, - :weak_dense_errors, - :wrap, - :calculate_error, - :initializealg, - :alg, - :save_noise, - :delta, - :seed, - :alg_hints, - :kwargshandle, - :trajectories, - :batch_size, - :sensealg, - :advance_to_tstop, - :stop_at_next_tstop, - :u0, - :p, - # These two are from the default algorithm handling - :default_set, - :second_time, - # This is for DiffEqDevTools - :prob_choice, - # Jump problems - :alias_jump, - # This is for copying/deepcopying noise in StochasticDiffEq - :alias_noise, - # This is for SimpleNonlinearSolve handling for batched Nonlinear Solves - :batch, - # Shooting method in BVP needs to differentiate between these two categories - :nlsolve_kwargs, - :odesolve_kwargs, - # If Solvers which internally use linsolve - :linsolve_kwargs, - # Solvers internally using EnsembleProblem - :ensemblealg, - # Fine Grained Control of Tracing (Storing and Logging) during Solve - :show_trace, - :trace_level, - :store_trace, - # Termination condition for solvers - :termination_condition, - # For AbstractAliasSpecifier - :alias, - # Parameter estimation with BVP - :fit_parameters) - -const KWARGWARN_MESSAGE = """ - Unrecognized keyword arguments found. - The only allowed keyword arguments to `solve` are: - $allowedkeywords - - See https://diffeq.sciml.ai/stable/basics/common_solver_opts/ for more details. - - Set kwargshandle=KeywordArgError for an error message. - Set kwargshandle=KeywordArgSilent to ignore this message. - """ - -const KWARGERROR_MESSAGE = """ - Unrecognized keyword arguments found. - The only allowed keyword arguments to `solve` are: - $allowedkeywords - - See https://diffeq.sciml.ai/stable/basics/common_solver_opts/ for more details. - """ - -struct CommonKwargError <: Exception - kwargs::Any -end - -function Base.showerror(io::IO, e::CommonKwargError) - println(io, KWARGERROR_MESSAGE) - notin = collect(map(x -> x ∉ allowedkeywords, keys(e.kwargs))) - unrecognized = collect(keys(e.kwargs))[notin] - print(io, "Unrecognized keyword arguments: ") - printstyled(io, unrecognized; bold = true, color = :red) - print(io, "\n\n") - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -@enum KeywordArgError KeywordArgWarn KeywordArgSilent - -const INCOMPATIBLE_U0_MESSAGE = """ - Initial condition incompatible with functional form. - Detected an in-place function with an initial condition of type Number or SArray. - This is incompatible because Numbers cannot be mutated, i.e. - `x = 2.0; y = 2.0; x .= y` will error. - - If using a immutable initial condition type, please use the out-of-place form. - I.e. define the function `du=f(u,p,t)` instead of attempting to "mutate" the immutable `du`. - - If your differential equation function was defined with multiple dispatches and one is - in-place, then the automatic detection will choose in-place. In this case, override the - choice in the problem constructor, i.e. `ODEProblem{false}(f,u0,tspan,p,kwargs...)`. - - For a longer discussion on mutability vs immutability and in-place vs out-of-place, see: - https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Example-Accelerating-a-Non-Stiff-Equation:-The-Lorenz-Equation - """ - -struct IncompatibleInitialConditionError <: Exception end - -function Base.showerror(io::IO, e::IncompatibleInitialConditionError) - print(io, INCOMPATIBLE_U0_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NO_DEFAULT_ALGORITHM_MESSAGE = """ - Default algorithm choices require DifferentialEquations.jl. - Please specify an algorithm (e.g., `solve(prob, Tsit5())` or - `init(prob, Tsit5())` for an ODE) or import DifferentialEquations - directly. - - You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ - and its associated pages. - """ - -struct NoDefaultAlgorithmError <: Exception end - -function Base.showerror(io::IO, e::NoDefaultAlgorithmError) - print(io, NO_DEFAULT_ALGORITHM_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NO_TSPAN_MESSAGE = """ - No tspan is set in the problem or chosen in the init/solve call - """ - -struct NoTspanError <: Exception end - -function Base.showerror(io::IO, e::NoTspanError) - print(io, NO_TSPAN_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NAN_TSPAN_MESSAGE = """ - NaN tspan is set in the problem or chosen in the init/solve call. - Note that -Inf and Inf values are allowed in the timespan for solves - which are terminated via callbacks, however NaN values are not allowed - since the direction of time is undetermined. - """ - -struct NaNTspanError <: Exception end - -function Base.showerror(io::IO, e::NaNTspanError) - print(io, NAN_TSPAN_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NON_SOLVER_MESSAGE = """ - The arguments to solve are incorrect. - The second argument must be a solver choice, `solve(prob,alg)` - where `alg` is a `<: AbstractDEAlgorithm`, e.g. `Tsit5()`. - - Please double check the arguments being sent to the solver. - - You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ - and its associated pages. - """ - -struct NonSolverError <: Exception end - -function Base.showerror(io::IO, e::NonSolverError) - print(io, NON_SOLVER_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NOISE_SIZE_MESSAGE = """ - Noise sizes are incompatible. The expected number of noise terms in the defined - `noise_rate_prototype` does not match the number of noise terms in the defined - `AbstractNoiseProcess`. Please ensure that - size(prob.noise_rate_prototype,2) == length(prob.noise.W[1]). - - Note: Noise process definitions require that users specify `u0`, and this value is - directly used in the definition. For example, if `noise = WienerProcess(0.0,0.0)`, - then the noise process is a scalar with `u0=0.0`. If `noise = WienerProcess(0.0,[0.0])`, - then the noise process is a vector with `u0=0.0`. If `noise_rate_prototype = zeros(2,4)`, - then the noise process must be a 4-dimensional process, for example - `noise = WienerProcess(0.0,zeros(4))`. This error is a sign that the user definition - of `noise_rate_prototype` and `noise` are not aligned in this manner and the definitions should - be double checked. - """ - -struct NoiseSizeIncompatabilityError <: Exception - prototypesize::Int - noisesize::Int -end - -function Base.showerror(io::IO, e::NoiseSizeIncompatabilityError) - println(io, NOISE_SIZE_MESSAGE) - println(io, "size(prob.noise_rate_prototype,2) = $(e.prototypesize)") - println(io, "length(prob.noise.W[1]) = $(e.noisesize)") - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const PROBSOLVER_PAIRING_MESSAGE = """ - Incompatible problem+solver pairing. - For example, this can occur if an ODE solver is passed with an SDEProblem. - Solvers are only capable of handling specific problem types. Please double - check that the chosen pairing is capable for handling the given problems. - """ - -struct ProblemSolverPairingError <: Exception - prob::Any - alg::Any -end - -function Base.showerror(io::IO, e::ProblemSolverPairingError) - println(io, PROBSOLVER_PAIRING_MESSAGE) - println(io, "Problem type: $(SciMLBase.__parameterless_type(typeof(e.prob)))") - println(io, "Solver type: $(SciMLBase.__parameterless_type(typeof(e.alg)))") - println(io, - "Problem types compatible with the chosen solver: $(compatible_problem_types(e.prob,e.alg))") - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -function compatible_problem_types(prob, alg) - if alg isa AbstractODEAlgorithm - ODEProblem - elseif alg isa AbstractSDEAlgorithm - (SDEProblem, SDDEProblem) - elseif alg isa AbstractDDEAlgorithm # StochasticDelayDiffEq.jl just uses the SDE alg - DDEProblem - elseif alg isa AbstractDAEAlgorithm - DAEProblem - elseif alg isa AbstractSteadyStateAlgorithm - SteadyStateProblem - end -end - -const DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE = """ - Incompatible solver + automatic differentiation pairing. - The chosen automatic differentiation algorithm requires the ability - for compiler transforms on the code which is only possible on pure-Julia - solvers such as those from OrdinaryDiffEq.jl. Direct differentiation methods - which require this ability include: - - - Direct use of ForwardDiff.jl on the solver - - `ForwardDiffSensitivity`, `ReverseDiffAdjoint`, `TrackerAdjoint`, and `ZygoteAdjoint` - sensealg choices for adjoint differentiation. - - Either switch the choice of solver to a pure Julia method, or change the automatic - differentiation method to one that does not require such transformations. - - For more details on automatic differentiation, adjoint, and sensitivity analysis - of differential equations, see the documentation page: - - https://diffeq.sciml.ai/stable/analysis/sensitivity/ - """ - -struct DirectAutodiffError <: Exception end - -function Base.showerror(io::IO, e::DirectAutodiffError) - println(io, DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NONCONCRETE_ELTYPE_MESSAGE = """ - Non-concrete element type inside of an `Array` detected. - Arrays with non-concrete element types, such as - `Array{Union{Float32,Float64}}`, are not supported by the - differential equation solvers. Anyways, this is bad for - performance so you don't want to be doing this! - - If this was a mistake, promote the element types to be - all the same. If this was intentional, for example, - using Unitful.jl with different unit values, then use - an array type which has fast broadcast support for - heterogeneous values such as the ArrayPartition - from RecursiveArrayTools.jl. For example: - - ```julia - using RecursiveArrayTools - x = ArrayPartition([1.0,2.0],[1f0,2f0]) - y = ArrayPartition([3.0,4.0],[3f0,4f0]) - x .+ y # fast, stable, and usable as u0 into DiffEq! - ``` - - Element type: - """ - -struct NonConcreteEltypeError <: Exception - eltype::Any -end - -function Base.showerror(io::IO, e::NonConcreteEltypeError) - print(io, NONCONCRETE_ELTYPE_MESSAGE) - print(io, e.eltype) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NONNUMBER_ELTYPE_MESSAGE = """ - Non-Number element type inside of an `Array` detected. - Arrays with non-number element types, such as - `Array{Array{Float64}}`, are not supported by the - solvers. - - If you are trying to use an array of arrays structure, - look at the tools in RecursiveArrayTools.jl. For example: - - If this was a mistake, promote the element types to be - all the same. If this was intentional, for example, - using Unitful.jl with different unit values, then use - an array type which has fast broadcast support for - heterogeneous values such as the ArrayPartition - from RecursiveArrayTools.jl. For example: - - ```julia - using RecursiveArrayTools - u0 = ArrayPartition([1.0,2.0],[3.0,4.0]) - u0 = VectorOfArray([1.0,2.0],[3.0,4.0]) - ``` - - are both initial conditions which would be compatible with - the solvers. Or use ComponentArrays.jl for more complex - nested structures. - - Element type: - """ - -struct NonNumberEltypeError <: Exception - eltype::Any -end - -function Base.showerror(io::IO, e::NonNumberEltypeError) - print(io, NONNUMBER_ELTYPE_MESSAGE) - print(io, e.eltype) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const GENERIC_NUMBER_TYPE_ERROR_MESSAGE = """ - Non-standard number type (i.e. not Float32, Float64, - ComplexF32, or ComplexF64) detected as the element type - for the initial condition or time span. These generic - number types are only compatible with the pure Julia - solvers which support generic programming, such as - OrdinaryDiffEq.jl. The chosen solver does not support - this functionality. Please double check that the initial - condition and time span types are correct, and check that - the chosen solver was correct. - """ - -struct GenericNumberTypeError <: Exception - alg::Any - uType::Any - tType::Any -end - -function Base.showerror(io::IO, e::GenericNumberTypeError) - println(io, GENERIC_NUMBER_TYPE_ERROR_MESSAGE) - println(io, "Solver: $(e.alg)") - println(io, "u0 type: $(e.uType)") - print(io, "Timespan type: $(e.tType)") - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const COMPLEX_SUPPORT_ERROR_MESSAGE = """ - Complex number type (i.e. ComplexF32, or ComplexF64) - detected as the element type for the initial condition - with an algorithm that does not support complex numbers. - Please check that the initial condition type is correct. - If complex number support is needed, try different solvers - such as those from OrdinaryDiffEq.jl. - """ - -struct ComplexSupportError <: Exception - alg::Any -end - -function Base.showerror(io::IO, e::ComplexSupportError) - println(io, COMPLEX_SUPPORT_ERROR_MESSAGE) - println(io, "Solver: $(e.alg)") - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const COMPLEX_TSPAN_ERROR_MESSAGE = """ - Complex number type (i.e. ComplexF32, or ComplexF64) - detected as the element type for the independent variable - (i.e. time span). Please check that the tspan type is correct. - No solvers support complex time spans. If this is required, - please open an issue. - """ - -struct ComplexTspanError <: Exception end - -function Base.showerror(io::IO, e::ComplexTspanError) - println(io, COMPLEX_TSPAN_ERROR_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const TUPLE_STATE_ERROR_MESSAGE = """ - Tuple type used as a state. Since a tuple does not have vector - properties, it will not work as a state type in equation solvers. - Instead, change your equation from using tuple constructors `()` - to static array constructors `SA[]`. For example, change: - - ```julia - function ftup((a,b),p,t) - return b,-a - end - u0 = (1.0,2.0) - tspan = (0.0,1.0) - ODEProblem(ftup,u0,tspan) - ``` - - to: - - ```julia - using StaticArrays - function fsa(u,p,t) - SA[u[2],u[1]] - end - u0 = SA[1.0,2.0] - tspan = (0.0,1.0) - ODEProblem(ftup,u0,tspan) - ``` - - This will be safer and fast for small ODEs. For more information, see: - https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Further-Optimizations-of-Small-Non-Stiff-ODEs-with-StaticArrays - """ - -struct TupleStateError <: Exception end - -function Base.showerror(io::IO, e::TupleStateError) - println(io, TUPLE_STATE_ERROR_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const MASS_MATRIX_ERROR_MESSAGE = """ - Mass matrix size is incompatible with initial condition - sizing. The mass matrix must represent the `vec` - form of the initial condition `u0`, i.e. - `size(mm,1) == size(mm,2) == length(u)` - """ - -struct IncompatibleMassMatrixError <: Exception - sz::Int - len::Int -end - -function Base.showerror(io::IO, e::IncompatibleMassMatrixError) - println(io, MASS_MATRIX_ERROR_MESSAGE) - print(io, "size(prob.f.mass_matrix,1): ") - println(io, e.sz) - print(io, "length(u0): ") - println(e.len) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const LATE_BINDING_TSTOPS_ERROR_MESSAGE = """ - This solver does not support providing `tstops` as a function. - Consider using a different solver or providing `tstops` as an array - of times. - """ - -struct LateBindingTstopsNotSupportedError <: Exception end - -function Base.showerror(io::IO, e::LateBindingTstopsNotSupportedError) - println(io, LATE_BINDING_TSTOPS_ERROR_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -""" - $(TYPEDSIGNATURES) - -Given the index provider `indp` used to construct the problem `prob` being solved, return -an updated `prob` to be used for solving. All implementations should accept arbitrary -keyword arguments. - -Should be called before the problem is solved, after performing type-promotion on the -problem. If the returned problem is not `===` the provided `prob`, it is assumed to -contain the `u0` and `p` passed as keyword arguments. - -# Keyword Arguments - -- `u0`, `p`: Override values for `state_values(prob)` and `parameter_values(prob)` which - should be used instead of the ones in `prob`. -""" -function get_updated_symbolic_problem(indp, prob; kw...) - return prob -end +# has_kwargs(_prob::AbstractDEProblem) = has_kwargs(typeof(_prob)) +# Base.@pure __has_kwargs(::Type{T}) where {T} = :kwargs ∈ fieldnames(T) +# has_kwargs(::Type{T}) where {T} = __has_kwargs(T) + +# const allowedkeywords = (:dense, +# :saveat, +# :save_idxs, +# :tstops, +# :tspan, +# :d_discontinuities, +# :save_everystep, +# :save_on, +# :save_start, +# :save_end, +# :initialize_save, +# :adaptive, +# :abstol, +# :reltol, +# :dt, +# :dtmax, +# :dtmin, +# :force_dtmin, +# :internalnorm, +# :controller, +# :gamma, +# :beta1, +# :beta2, +# :qmax, +# :qmin, +# :qsteady_min, +# :qsteady_max, +# :qoldinit, +# :failfactor, +# :calck, +# :alias_u0, +# :maxiters, +# :maxtime, +# :callback, +# :isoutofdomain, +# :unstable_check, +# :verbose, +# :merge_callbacks, +# :progress, +# :progress_steps, +# :progress_name, +# :progress_message, +# :progress_id, +# :timeseries_errors, +# :dense_errors, +# :weak_timeseries_errors, +# :weak_dense_errors, +# :wrap, +# :calculate_error, +# :initializealg, +# :alg, +# :save_noise, +# :delta, +# :seed, +# :alg_hints, +# :kwargshandle, +# :trajectories, +# :batch_size, +# :sensealg, +# :advance_to_tstop, +# :stop_at_next_tstop, +# :u0, +# :p, +# # These two are from the default algorithm handling +# :default_set, +# :second_time, +# # This is for DiffEqDevTools +# :prob_choice, +# # Jump problems +# :alias_jump, +# # This is for copying/deepcopying noise in StochasticDiffEq +# :alias_noise, +# # This is for SimpleNonlinearSolve handling for batched Nonlinear Solves +# :batch, +# # Shooting method in BVP needs to differentiate between these two categories +# :nlsolve_kwargs, +# :odesolve_kwargs, +# # If Solvers which internally use linsolve +# :linsolve_kwargs, +# # Solvers internally using EnsembleProblem +# :ensemblealg, +# # Fine Grained Control of Tracing (Storing and Logging) during Solve +# :show_trace, +# :trace_level, +# :store_trace, +# # Termination condition for solvers +# :termination_condition, +# # For AbstractAliasSpecifier +# :alias, +# # Parameter estimation with BVP +# :fit_parameters) + +# const KWARGWARN_MESSAGE = """ +# Unrecognized keyword arguments found. +# The only allowed keyword arguments to `solve` are: +# $allowedkeywords + +# See https://diffeq.sciml.ai/stable/basics/common_solver_opts/ for more details. + +# Set kwargshandle=KeywordArgError for an error message. +# Set kwargshandle=KeywordArgSilent to ignore this message. +# """ + +# const KWARGERROR_MESSAGE = """ +# Unrecognized keyword arguments found. +# The only allowed keyword arguments to `solve` are: +# $allowedkeywords + +# See https://diffeq.sciml.ai/stable/basics/common_solver_opts/ for more details. +# """ + +# struct CommonKwargError <: Exception +# kwargs::Any +# end + +# function Base.showerror(io::IO, e::CommonKwargError) +# println(io, KWARGERROR_MESSAGE) +# notin = collect(map(x -> x ∉ allowedkeywords, keys(e.kwargs))) +# unrecognized = collect(keys(e.kwargs))[notin] +# print(io, "Unrecognized keyword arguments: ") +# printstyled(io, unrecognized; bold = true, color = :red) +# print(io, "\n\n") +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# @enum KeywordArgError KeywordArgWarn KeywordArgSilent + +# const INCOMPATIBLE_U0_MESSAGE = """ +# Initial condition incompatible with functional form. +# Detected an in-place function with an initial condition of type Number or SArray. +# This is incompatible because Numbers cannot be mutated, i.e. +# `x = 2.0; y = 2.0; x .= y` will error. + +# If using a immutable initial condition type, please use the out-of-place form. +# I.e. define the function `du=f(u,p,t)` instead of attempting to "mutate" the immutable `du`. + +# If your differential equation function was defined with multiple dispatches and one is +# in-place, then the automatic detection will choose in-place. In this case, override the +# choice in the problem constructor, i.e. `ODEProblem{false}(f,u0,tspan,p,kwargs...)`. + +# For a longer discussion on mutability vs immutability and in-place vs out-of-place, see: +# https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Example-Accelerating-a-Non-Stiff-Equation:-The-Lorenz-Equation +# """ + +# struct IncompatibleInitialConditionError <: Exception end + +# function Base.showerror(io::IO, e::IncompatibleInitialConditionError) +# print(io, INCOMPATIBLE_U0_MESSAGE) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const NO_DEFAULT_ALGORITHM_MESSAGE = """ +# Default algorithm choices require DifferentialEquations.jl. +# Please specify an algorithm (e.g., `solve(prob, Tsit5())` or +# `init(prob, Tsit5())` for an ODE) or import DifferentialEquations +# directly. + +# You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ +# and its associated pages. +# """ + +# struct NoDefaultAlgorithmError <: Exception end + +# function Base.showerror(io::IO, e::NoDefaultAlgorithmError) +# print(io, NO_DEFAULT_ALGORITHM_MESSAGE) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const NO_TSPAN_MESSAGE = """ +# No tspan is set in the problem or chosen in the init/solve call +# """ + +# struct NoTspanError <: Exception end + +# function Base.showerror(io::IO, e::NoTspanError) +# print(io, NO_TSPAN_MESSAGE) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const NAN_TSPAN_MESSAGE = """ +# NaN tspan is set in the problem or chosen in the init/solve call. +# Note that -Inf and Inf values are allowed in the timespan for solves +# which are terminated via callbacks, however NaN values are not allowed +# since the direction of time is undetermined. +# """ + +# struct NaNTspanError <: Exception end + +# function Base.showerror(io::IO, e::NaNTspanError) +# print(io, NAN_TSPAN_MESSAGE) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const NON_SOLVER_MESSAGE = """ +# The arguments to solve are incorrect. +# The second argument must be a solver choice, `solve(prob,alg)` +# where `alg` is a `<: AbstractDEAlgorithm`, e.g. `Tsit5()`. + +# Please double check the arguments being sent to the solver. + +# You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ +# and its associated pages. +# """ + +# struct NonSolverError <: Exception end + +# function Base.showerror(io::IO, e::NonSolverError) +# print(io, NON_SOLVER_MESSAGE) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const NOISE_SIZE_MESSAGE = """ +# Noise sizes are incompatible. The expected number of noise terms in the defined +# `noise_rate_prototype` does not match the number of noise terms in the defined +# `AbstractNoiseProcess`. Please ensure that +# size(prob.noise_rate_prototype,2) == length(prob.noise.W[1]). + +# Note: Noise process definitions require that users specify `u0`, and this value is +# directly used in the definition. For example, if `noise = WienerProcess(0.0,0.0)`, +# then the noise process is a scalar with `u0=0.0`. If `noise = WienerProcess(0.0,[0.0])`, +# then the noise process is a vector with `u0=0.0`. If `noise_rate_prototype = zeros(2,4)`, +# then the noise process must be a 4-dimensional process, for example +# `noise = WienerProcess(0.0,zeros(4))`. This error is a sign that the user definition +# of `noise_rate_prototype` and `noise` are not aligned in this manner and the definitions should +# be double checked. +# """ + +# struct NoiseSizeIncompatabilityError <: Exception +# prototypesize::Int +# noisesize::Int +# end + +# function Base.showerror(io::IO, e::NoiseSizeIncompatabilityError) +# println(io, NOISE_SIZE_MESSAGE) +# println(io, "size(prob.noise_rate_prototype,2) = $(e.prototypesize)") +# println(io, "length(prob.noise.W[1]) = $(e.noisesize)") +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const PROBSOLVER_PAIRING_MESSAGE = """ +# Incompatible problem+solver pairing. +# For example, this can occur if an ODE solver is passed with an SDEProblem. +# Solvers are only capable of handling specific problem types. Please double +# check that the chosen pairing is capable for handling the given problems. +# """ + +# struct ProblemSolverPairingError <: Exception +# prob::Any +# alg::Any +# end + +# function Base.showerror(io::IO, e::ProblemSolverPairingError) +# println(io, PROBSOLVER_PAIRING_MESSAGE) +# println(io, "Problem type: $(SciMLBase.__parameterless_type(typeof(e.prob)))") +# println(io, "Solver type: $(SciMLBase.__parameterless_type(typeof(e.alg)))") +# println(io, +# "Problem types compatible with the chosen solver: $(compatible_problem_types(e.prob,e.alg))") +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# function compatible_problem_types(prob, alg) +# if alg isa AbstractODEAlgorithm +# ODEProblem +# elseif alg isa AbstractSDEAlgorithm +# (SDEProblem, SDDEProblem) +# elseif alg isa AbstractDDEAlgorithm # StochasticDelayDiffEq.jl just uses the SDE alg +# DDEProblem +# elseif alg isa AbstractDAEAlgorithm +# DAEProblem +# elseif alg isa AbstractSteadyStateAlgorithm +# SteadyStateProblem +# end +# end + +# const DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE = """ +# Incompatible solver + automatic differentiation pairing. +# The chosen automatic differentiation algorithm requires the ability +# for compiler transforms on the code which is only possible on pure-Julia +# solvers such as those from OrdinaryDiffEq.jl. Direct differentiation methods +# which require this ability include: + +# - Direct use of ForwardDiff.jl on the solver +# - `ForwardDiffSensitivity`, `ReverseDiffAdjoint`, `TrackerAdjoint`, and `ZygoteAdjoint` +# sensealg choices for adjoint differentiation. + +# Either switch the choice of solver to a pure Julia method, or change the automatic +# differentiation method to one that does not require such transformations. + +# For more details on automatic differentiation, adjoint, and sensitivity analysis +# of differential equations, see the documentation page: + +# https://diffeq.sciml.ai/stable/analysis/sensitivity/ +# """ + +# struct DirectAutodiffError <: Exception end + +# function Base.showerror(io::IO, e::DirectAutodiffError) +# println(io, DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const NONCONCRETE_ELTYPE_MESSAGE = """ +# Non-concrete element type inside of an `Array` detected. +# Arrays with non-concrete element types, such as +# `Array{Union{Float32,Float64}}`, are not supported by the +# differential equation solvers. Anyways, this is bad for +# performance so you don't want to be doing this! + +# If this was a mistake, promote the element types to be +# all the same. If this was intentional, for example, +# using Unitful.jl with different unit values, then use +# an array type which has fast broadcast support for +# heterogeneous values such as the ArrayPartition +# from RecursiveArrayTools.jl. For example: + +# ```julia +# using RecursiveArrayTools +# x = ArrayPartition([1.0,2.0],[1f0,2f0]) +# y = ArrayPartition([3.0,4.0],[3f0,4f0]) +# x .+ y # fast, stable, and usable as u0 into DiffEq! +# ``` + +# Element type: +# """ + +# struct NonConcreteEltypeError <: Exception +# eltype::Any +# end + +# function Base.showerror(io::IO, e::NonConcreteEltypeError) +# print(io, NONCONCRETE_ELTYPE_MESSAGE) +# print(io, e.eltype) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const NONNUMBER_ELTYPE_MESSAGE = """ +# Non-Number element type inside of an `Array` detected. +# Arrays with non-number element types, such as +# `Array{Array{Float64}}`, are not supported by the +# solvers. + +# If you are trying to use an array of arrays structure, +# look at the tools in RecursiveArrayTools.jl. For example: + +# If this was a mistake, promote the element types to be +# all the same. If this was intentional, for example, +# using Unitful.jl with different unit values, then use +# an array type which has fast broadcast support for +# heterogeneous values such as the ArrayPartition +# from RecursiveArrayTools.jl. For example: + +# ```julia +# using RecursiveArrayTools +# u0 = ArrayPartition([1.0,2.0],[3.0,4.0]) +# u0 = VectorOfArray([1.0,2.0],[3.0,4.0]) +# ``` + +# are both initial conditions which would be compatible with +# the solvers. Or use ComponentArrays.jl for more complex +# nested structures. + +# Element type: +# """ + +# struct NonNumberEltypeError <: Exception +# eltype::Any +# end + +# function Base.showerror(io::IO, e::NonNumberEltypeError) +# print(io, NONNUMBER_ELTYPE_MESSAGE) +# print(io, e.eltype) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const GENERIC_NUMBER_TYPE_ERROR_MESSAGE = """ +# Non-standard number type (i.e. not Float32, Float64, +# ComplexF32, or ComplexF64) detected as the element type +# for the initial condition or time span. These generic +# number types are only compatible with the pure Julia +# solvers which support generic programming, such as +# OrdinaryDiffEq.jl. The chosen solver does not support +# this functionality. Please double check that the initial +# condition and time span types are correct, and check that +# the chosen solver was correct. +# """ + +# struct GenericNumberTypeError <: Exception +# alg::Any +# uType::Any +# tType::Any +# end + +# function Base.showerror(io::IO, e::GenericNumberTypeError) +# println(io, GENERIC_NUMBER_TYPE_ERROR_MESSAGE) +# println(io, "Solver: $(e.alg)") +# println(io, "u0 type: $(e.uType)") +# print(io, "Timespan type: $(e.tType)") +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const COMPLEX_SUPPORT_ERROR_MESSAGE = """ +# Complex number type (i.e. ComplexF32, or ComplexF64) +# detected as the element type for the initial condition +# with an algorithm that does not support complex numbers. +# Please check that the initial condition type is correct. +# If complex number support is needed, try different solvers +# such as those from OrdinaryDiffEq.jl. +# """ + +# struct ComplexSupportError <: Exception +# alg::Any +# end + +# function Base.showerror(io::IO, e::ComplexSupportError) +# println(io, COMPLEX_SUPPORT_ERROR_MESSAGE) +# println(io, "Solver: $(e.alg)") +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const COMPLEX_TSPAN_ERROR_MESSAGE = """ +# Complex number type (i.e. ComplexF32, or ComplexF64) +# detected as the element type for the independent variable +# (i.e. time span). Please check that the tspan type is correct. +# No solvers support complex time spans. If this is required, +# please open an issue. +# """ + +# struct ComplexTspanError <: Exception end + +# function Base.showerror(io::IO, e::ComplexTspanError) +# println(io, COMPLEX_TSPAN_ERROR_MESSAGE) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const TUPLE_STATE_ERROR_MESSAGE = """ +# Tuple type used as a state. Since a tuple does not have vector +# properties, it will not work as a state type in equation solvers. +# Instead, change your equation from using tuple constructors `()` +# to static array constructors `SA[]`. For example, change: + +# ```julia +# function ftup((a,b),p,t) +# return b,-a +# end +# u0 = (1.0,2.0) +# tspan = (0.0,1.0) +# ODEProblem(ftup,u0,tspan) +# ``` + +# to: + +# ```julia +# using StaticArrays +# function fsa(u,p,t) +# SA[u[2],u[1]] +# end +# u0 = SA[1.0,2.0] +# tspan = (0.0,1.0) +# ODEProblem(ftup,u0,tspan) +# ``` + +# This will be safer and fast for small ODEs. For more information, see: +# https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Further-Optimizations-of-Small-Non-Stiff-ODEs-with-StaticArrays +# """ + +# struct TupleStateError <: Exception end + +# function Base.showerror(io::IO, e::TupleStateError) +# println(io, TUPLE_STATE_ERROR_MESSAGE) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const MASS_MATRIX_ERROR_MESSAGE = """ +# Mass matrix size is incompatible with initial condition +# sizing. The mass matrix must represent the `vec` +# form of the initial condition `u0`, i.e. +# `size(mm,1) == size(mm,2) == length(u)` +# """ + +# struct IncompatibleMassMatrixError <: Exception +# sz::Int +# len::Int +# end + +# function Base.showerror(io::IO, e::IncompatibleMassMatrixError) +# println(io, MASS_MATRIX_ERROR_MESSAGE) +# print(io, "size(prob.f.mass_matrix,1): ") +# println(io, e.sz) +# print(io, "length(u0): ") +# println(e.len) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# const LATE_BINDING_TSTOPS_ERROR_MESSAGE = """ +# This solver does not support providing `tstops` as a function. +# Consider using a different solver or providing `tstops` as an array +# of times. +# """ + +# struct LateBindingTstopsNotSupportedError <: Exception end + +# function Base.showerror(io::IO, e::LateBindingTstopsNotSupportedError) +# println(io, LATE_BINDING_TSTOPS_ERROR_MESSAGE) +# println(io, TruncatedStacktraces.VERBOSE_MSG) +# end + +# """ +# $(TYPEDSIGNATURES) + +# Given the index provider `indp` used to construct the problem `prob` being solved, return +# an updated `prob` to be used for solving. All implementations should accept arbitrary +# keyword arguments. + +# Should be called before the problem is solved, after performing type-promotion on the +# problem. If the returned problem is not `===` the provided `prob`, it is assumed to +# contain the `u0` and `p` passed as keyword arguments. + +# # Keyword Arguments + +# - `u0`, `p`: Override values for `state_values(prob)` and `parameter_values(prob)` which +# should be used instead of the ones in `prob`. +# """ +# function get_updated_symbolic_problem(indp, prob; kw...) +# return prob +# end function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing, kwargs...) - kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle + kwargshandle = kwargshandle === nothing ? SciMLBase.KeywordArgError : kwargshandle kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? _prob.kwargs[:kwargshandle] : kwargshandle @@ -613,7 +613,7 @@ end function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing, kwargs...) - kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle + kwargshandle = kwargshandle === nothing ? SciMLBase.KeywordArgError : kwargshandle kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? _prob.kwargs[:kwargshandle] : kwargshandle @@ -1240,21 +1240,21 @@ function solve(prob::AbstractJumpProblem, args...; kwargs...) __solve(prob, args...; kwargs...) end -function checkkwargs(kwargshandle; kwargs...) - if any(x -> x ∉ allowedkeywords, keys(kwargs)) - if kwargshandle == KeywordArgError - throw(CommonKwargError(kwargs)) - elseif kwargshandle == KeywordArgWarn - @warn KWARGWARN_MESSAGE - unrecognized = setdiff(keys(kwargs), allowedkeywords) - print("Unrecognized keyword arguments: ") - printstyled(unrecognized; bold = true, color = :red) - print("\n\n") - else - @assert kwargshandle == KeywordArgSilent - end - end -end +# function checkkwargs(kwargshandle; kwargs...) +# if any(x -> x ∉ allowedkeywords, keys(kwargs)) +# if kwargshandle == KeywordArgError +# throw(CommonKwargError(kwargs)) +# elseif kwargshandle == KeywordArgWarn +# @warn KWARGWARN_MESSAGE +# unrecognized = setdiff(keys(kwargs), allowedkeywords) +# print("Unrecognized keyword arguments: ") +# printstyled(unrecognized; bold = true, color = :red) +# print("\n\n") +# else +# @assert kwargshandle == KeywordArgSilent +# end +# end +# end function get_concrete_problem(prob::AbstractJumpProblem, isadapt; kwargs...) get_updated_symbolic_problem(SciMLBase.get_root_indp(prob), prob; kwargs...) @@ -1441,99 +1441,99 @@ function get_concrete_tspan(prob, isadapt, kwargs, p) tspan end -function isconcreteu0(prob, t0, kwargs) - !eval_u0(prob.u0) && prob.u0 !== nothing && !isdistribution(prob.u0) -end +# function isconcreteu0(prob, t0, kwargs) +# !eval_u0(prob.u0) && prob.u0 !== nothing && !isdistribution(prob.u0) +# end -function isconcretedu0(prob, t0, kwargs) - !eval_u0(prob.u0) && prob.du0 !== nothing && !isdistribution(prob.du0) -end +# function isconcretedu0(prob, t0, kwargs) +# !eval_u0(prob.u0) && prob.du0 !== nothing && !isdistribution(prob.du0) +# end -function get_concrete_u0(prob, isadapt, t0, kwargs) - if eval_u0(prob.u0) - u0 = prob.u0(prob.p, t0) - elseif haskey(kwargs, :u0) - u0 = kwargs[:u0] - else - u0 = prob.u0 - end +# function get_concrete_u0(prob, isadapt, t0, kwargs) +# if eval_u0(prob.u0) +# u0 = prob.u0(prob.p, t0) +# elseif haskey(kwargs, :u0) +# u0 = kwargs[:u0] +# else +# u0 = prob.u0 +# end - isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) +# isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) - _u0 = handle_distribution_u0(u0) +# _u0 = handle_distribution_u0(u0) - if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) - throw(IncompatibleInitialConditionError()) - end +# if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) +# throw(IncompatibleInitialConditionError()) +# end - nu0 = length(something(_u0, ())) - if isdefined(prob.f, :mass_matrix) && prob.f.mass_matrix !== nothing && - prob.f.mass_matrix isa AbstractArray && - size(prob.f.mass_matrix, 1) !== nu0 - throw(IncompatibleMassMatrixError(size(prob.f.mass_matrix, 1), nu0)) - end +# nu0 = length(something(_u0, ())) +# if isdefined(prob.f, :mass_matrix) && prob.f.mass_matrix !== nothing && +# prob.f.mass_matrix isa AbstractArray && +# size(prob.f.mass_matrix, 1) !== nu0 +# throw(IncompatibleMassMatrixError(size(prob.f.mass_matrix, 1), nu0)) +# end - if _u0 isa Tuple - throw(TupleStateError()) - end +# if _u0 isa Tuple +# throw(TupleStateError()) +# end - _u0 -end +# _u0 +# end -function get_concrete_u0(prob::BVProblem, isadapt, t0, kwargs) - if haskey(kwargs, :u0) - u0 = kwargs[:u0] - else - u0 = prob.u0 - end +# function get_concrete_u0(prob::BVProblem, isadapt, t0, kwargs) +# if haskey(kwargs, :u0) +# u0 = kwargs[:u0] +# else +# u0 = prob.u0 +# end - isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) +# isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) - _u0 = handle_distribution_u0(u0) +# _u0 = handle_distribution_u0(u0) - if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) - throw(IncompatibleInitialConditionError()) - end +# if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) +# throw(IncompatibleInitialConditionError()) +# end - if _u0 isa Tuple - throw(TupleStateError()) - end +# if _u0 isa Tuple +# throw(TupleStateError()) +# end - return _u0 -end +# return _u0 +# end -function get_concrete_du0(prob, isadapt, t0, kwargs) - if eval_u0(prob.du0) - du0 = prob.du0(prob.p, t0) - elseif haskey(kwargs, :du0) - du0 = kwargs[:du0] - else - du0 = prob.du0 - end +# function get_concrete_du0(prob, isadapt, t0, kwargs) +# if eval_u0(prob.du0) +# du0 = prob.du0(prob.p, t0) +# elseif haskey(kwargs, :du0) +# du0 = kwargs[:du0] +# else +# du0 = prob.du0 +# end - isadapt && eltype(du0) <: Integer && (du0 = float.(du0)) +# isadapt && eltype(du0) <: Integer && (du0 = float.(du0)) - _du0 = handle_distribution_u0(du0) +# _du0 = handle_distribution_u0(du0) - if isinplace(prob) && (_du0 isa Number || _du0 isa SArray) - throw(IncompatibleInitialConditionError()) - end +# if isinplace(prob) && (_du0 isa Number || _du0 isa SArray) +# throw(IncompatibleInitialConditionError()) +# end - _du0 -end +# _du0 +# end -function get_concrete_p(prob, kwargs) - if haskey(kwargs, :p) - p = kwargs[:p] - else - p = prob.p - end -end +# function get_concrete_p(prob, kwargs) +# if haskey(kwargs, :p) +# p = kwargs[:p] +# else +# p = prob.p +# end +# end -handle_distribution_u0(_u0) = _u0 +# handle_distribution_u0(_u0) = _u0 -eval_u0(u0::Function) = true -eval_u0(u0) = false +# eval_u0(u0::Function) = true +# eval_u0(u0) = false function __solve( prob::AbstractDEProblem, args...; default_set = false, second_time = false, @@ -1626,22 +1626,22 @@ function check_prob_alg_pairing(prob, alg) end end -@inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) - if isempty(solve_args) || isnothing(first(solve_args)) - if haskey(solve_kwargs, :alg) - solve_kwargs[:alg] - elseif haskey(prob_kwargs, :alg) - prob_kwargs[:alg] - else - nothing - end - elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm && - !(first(solve_args) isa SciMLBase.EnsembleAlgorithm) - first(solve_args) - else - nothing - end -end +# @inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) +# if isempty(solve_args) || isnothing(first(solve_args)) +# if haskey(solve_kwargs, :alg) +# solve_kwargs[:alg] +# elseif haskey(prob_kwargs, :alg) +# prob_kwargs[:alg] +# else +# nothing +# end +# elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm && +# !(first(solve_args) isa SciMLBase.EnsembleAlgorithm) +# first(solve_args) +# else +# nothing +# end +# end ################### Differentiation diff --git a/src/utils.jl b/src/utils.jl index d5c316348..671fdac8a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,28 +1,8 @@ -# Handled in Extensions -value(x) = x -unitfulvalue(x) = x -isdistribution(u0) = false -sse(x::Number) = abs2(x) - -# Static Arrays don't support the `init` keyword argument for `sum` -@inline __sum(f::F, args...; init, kwargs...) where {F} = sum(f, args...; init, kwargs...) -@inline function __sum( - f::F, a::StaticArraysCore.StaticArray...; init, kwargs...) where {F} - return mapreduce(f, +, a...; init, kwargs...) -end - -totallength(x::Number) = 1 -totallength(x::AbstractArray) = __sum(totallength, x; init = 0) - _vec(v) = vec(v) _vec(v::Number) = v _vec(v::AbstractSciMLScalarOperator) = v _vec(v::AbstractVector) = v -_reshape(v, siz) = reshape(v, siz) -_reshape(v::Number, siz) = v -_reshape(v::AbstractSciMLScalarOperator, siz) = v - macro tight_loop_macros(ex) :($(esc(ex))) end @@ -129,64 +109,3 @@ end @inline __add_and_norm(::typeof(Base.Fix2(norm, Inf)), x, y) = __maximum_abs(+, x, y) @inline __add_and_norm(f::F, x, y) where {F} = __norm_op(f, +, x, y) -struct DualEltypeChecker{T, T2} - x::T - counter::T2 -end - -anyeltypedual(x) = anyeltypedual(x, Val{0}) -anyeltypedual(x, counter) = Any - -function promote_u0(u0, p, t0) - if SciMLStructures.isscimlstructure(p) - _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] - if !isequal(_p, p) - return promote_u0(u0, _p, t0) - end - end - Tu = eltype(u0) - if isdualtype(Tu) - return u0 - end - Tp = anyeltypedual(p, Val{0}) - if Tp == Any - Tp = Tu - end - Tt = anyeltypedual(t0, Val{0}) - if Tt == Any - Tt = Tu - end - Tcommon = promote_type(Tu, Tp, Tt) - return if isdualtype(Tcommon) - Tcommon.(u0) - else - u0 - end -end - -function promote_u0(u0::AbstractArray{<:Complex}, p, t0) - if SciMLStructures.isscimlstructure(p) - _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] - if !isequal(_p, p) - return promote_u0(u0, _p, t0) - end - end - Tu = real(eltype(u0)) - if isdualtype(Tu) - return u0 - end - Tp = anyeltypedual(p, Val{0}) - if Tp == Any - Tp = Tu - end - Tt = anyeltypedual(t0, Val{0}) - if Tt == Any - Tt = Tu - end - Tcommon = promote_type(eltype(u0), Tp, Tt) - return if isdualtype(real(Tcommon)) - Tcommon.(u0) - else - u0 - end -end From 9b3080f2cfd31e7d44f4200280c007e527814b6d Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 18 Aug 2025 12:35:12 -0400 Subject: [PATCH 02/17] import needed things from SciMLBase --- src/DiffEqBase.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index a6cf2fa0c..618ea35cf 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -94,7 +94,11 @@ using SciMLBase: @def, DEIntegrator, AbstractDEProblem, import SciMLBase: solve, init, step!, solve!, __init, __solve, update_coefficients!, update_coefficients, isadaptive, wrapfun_oop, wrapfun_iip, - unwrap_fw, promote_tspan, set_u!, set_t!, set_ut! + unwrap_fw, promote_tspan, set_u!, set_t!, set_ut!, + extract_alg, checkkwargs, has_kwargs, + eltypedual, get_updated_symbolic_problem, get_concrete_p, get_concrete_u0, promote_u0, + isconcreteu0, isconcretedu0, get_concrete_du0, _reshape, value, unitfulvalue, anyeltypedual, allowedkeywords, + sse, totallength, __sum, DualEltypeChecker import SciMLStructures @@ -107,10 +111,6 @@ import SymbolicIndexingInterface as SII ## Extension Functions -eltypedual(x) = false -promote_u0(::Nothing, p, t0) = nothing -isdualtype(::Type{T}) where {T} = false - ## Types """ @@ -167,6 +167,4 @@ export initialize!, finalize! export SensitivityADPassThrough -export KeywordArgError, KeywordArgWarn, KeywordArgSilent - end # module From 5d4da36d4090815ec971e3f9b41299c3dff0350c Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 18 Aug 2025 12:35:25 -0400 Subject: [PATCH 03/17] get rid of Distributions extension --- Project.toml | 1 - ext/DiffEqBaseDistributionsExt.jl | 8 -------- 2 files changed, 9 deletions(-) delete mode 100644 ext/DiffEqBaseDistributionsExt.jl diff --git a/Project.toml b/Project.toml index 4aaa9f6a3..5997760af 100644 --- a/Project.toml +++ b/Project.toml @@ -51,7 +51,6 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [extensions] DiffEqBaseCUDAExt = "CUDA" DiffEqBaseChainRulesCoreExt = "ChainRulesCore" -DiffEqBaseDistributionsExt = "Distributions" DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"] DiffEqBaseForwardDiffExt = ["ForwardDiff"] DiffEqBaseGTPSAExt = "GTPSA" diff --git a/ext/DiffEqBaseDistributionsExt.jl b/ext/DiffEqBaseDistributionsExt.jl deleted file mode 100644 index e84a5509a..000000000 --- a/ext/DiffEqBaseDistributionsExt.jl +++ /dev/null @@ -1,8 +0,0 @@ -module DiffEqBaseDistributionsExt - -using Distributions, DiffEqBase - -DiffEqBase.handle_distribution_u0(_u0::Distributions.Sampleable) = rand(_u0) -DiffEqBase.isdistribution(_u0::Distributions.Sampleable) = true - -end From bd458cdb245056796356cfc5f66738a8c8150f6b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Mon, 18 Aug 2025 12:36:27 -0400 Subject: [PATCH 04/17] excise extension things that are now in SciMLBase --- ext/DiffEqBaseForwardDiffExt.jl | 431 +-------------------- ext/DiffEqBaseGTPSAExt.jl | 3 +- ext/DiffEqBaseMeasurementsExt.jl | 12 - ext/DiffEqBaseMonteCarloMeasurementsExt.jl | 28 -- ext/DiffEqBaseReverseDiffExt.jl | 51 --- ext/DiffEqBaseTrackerExt.jl | 30 -- ext/DiffEqBaseUnitfulExt.jl | 2 +- 7 files changed, 20 insertions(+), 537 deletions(-) diff --git a/ext/DiffEqBaseForwardDiffExt.jl b/ext/DiffEqBaseForwardDiffExt.jl index b328eedbc..0fd81e0dc 100644 --- a/ext/DiffEqBaseForwardDiffExt.jl +++ b/ext/DiffEqBaseForwardDiffExt.jl @@ -6,34 +6,13 @@ using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag, AbstractTimeseriesSolution, RecursiveArrayTools, reduce_tup, _promote_tspan, has_continuous_callback import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, prob2dtmin, - promote_tspan, anyeltypedual, isdualtype, value, ODE_DEFAULT_NORM, - InternalITP, nextfloat_tdir, DualEltypeChecker, sse, unitfulvalue + promote_tspan, ODE_DEFAULT_NORM, + InternalITP, nextfloat_tdir +import SciMLBase: isdualtype, DualEltypeChecker, sse, __sum -eltypedual(x) = eltype(x) <: ForwardDiff.Dual -isdualtype(::Type{<:ForwardDiff.Dual}) = true const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1} dualgen(::Type{T}) where {T} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, 1} -# Copy of the other prob2dtmin dispatch, just for optionality -function prob2dtmin(tspan, ::ForwardDiff.Dual, use_end_time) - t1, t2 = tspan - isfinite(t1) || throw(ArgumentError("t0 in the tspan `(t0, t1)` must be finite")) - if use_end_time && isfinite(t2 - t1) - return max(eps(t2), eps(t1)) - else - return max(eps(typeof(t1)), eps(t1)) - end -end - -function hasdualpromote(u0, t::Number) - hasmethod(ArrayInterface.promote_eltype, - Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) && - hasmethod(promote_rule, - Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) && - hasmethod(promote_rule, - Tuple{Type{eltype(u0)}, Type{typeof(t)}}) -end - const NORECOMPILE_IIP_SUPPORTED_ARGS = ( Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}, @@ -111,400 +90,24 @@ function wrapfun_iip(@nospecialize(ff)) FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt) end -promote_dual(::Type{T}, ::Type{T2}) where {T <: ForwardDiff.Dual, T2} = T -function promote_dual(::Type{T}, - ::Type{T2}) where {T <: ForwardDiff.Dual, T2 <: ForwardDiff.Dual} - T -end -promote_dual(::Type{T}, ::Type{T2}) where {T, T2 <: ForwardDiff.Dual} = T2 - -function promote_dual(::Type{T}, - ::Type{T2}) where {T3, T4, V, V2 <: ForwardDiff.Dual, N, N2, - T <: ForwardDiff.Dual{T3, V, N}, - T2 <: ForwardDiff.Dual{T4, V2, N2}} - T2 -end -function promote_dual(::Type{T}, - ::Type{T2}) where {T3, T4, V <: ForwardDiff.Dual, V2, N, N2, - T <: ForwardDiff.Dual{T3, V, N}, - T2 <: ForwardDiff.Dual{T4, V2, N2}} - T -end -function promote_dual(::Type{T}, - ::Type{T2}) where { - T3, V <: ForwardDiff.Dual, V2 <: ForwardDiff.Dual, - N, - T <: ForwardDiff.Dual{T3, V, N}, - T2 <: ForwardDiff.Dual{T3, V2, N}} - ForwardDiff.Dual{T3, promote_dual(V, V2), N} -end - -""" - promote_dual(::Type{T},::Type{T2}) - - -Is like the number promotion system, but always prefers a dual number type above -anything else. For higher order differentiation, it returns the most dualiest of -them all. This is then used to promote `u0` into the suspected highest differentiation -space for solving the equation. -""" -promote_dual(::Type{T}, ::Type{T2}) where {T, T2} = T - -# `reduce` and `map` are specialized on tuples to be unrolled (via recursion) -# Therefore, they can be type stable even with heterogeneous input types. -# We also don't care about allocating any temporaries with them, as it should -# all be unrolled and optimized away. -# Being unrolled also means const prop can work for things like -# `mapreduce(f, op, propertynames(x))` -# where `f` may call `getproperty` and thus have return type dependent -# on the particular symbol. -# `mapreduce` hasn't received any such specialization. -@inline diffeqmapreduce(f::F, op::OP, x::Tuple) where {F, OP} = reduce_tup(op, map(f, x)) -@inline function diffeqmapreduce(f::F, op::OP, x::NamedTuple) where {F, OP} - reduce_tup(op, map(f, x)) -end -# For other container types, we probably just want to call `mapreduce` -@inline diffeqmapreduce(f::F, op::OP, x) where {F, OP} = mapreduce(f, op, x, init = Any) - -getval(::Val{I}) where {I} = I -getval(::Type{Val{I}}) where {I} = I -getval(I::Int) = I - -const DUALCHECK_RECURSION_MAX = 10 - -function (dec::DualEltypeChecker)(::Val{Y}) where {Y} - isdefined(dec.x, Y) || return Any - getval(dec.counter) >= DUALCHECK_RECURSION_MAX && return Any - anyeltypedual(getfield(dec.x, Y), Val{getval(dec.counter)}) -end - -# Untyped dispatch: catch composite types, check all of their fields -""" - anyeltypedual(x) - - -Searches through a type to see if any of its values are parameters. This is used to -then promote other values to match the dual type. For example, if a user passes a parameter - -which is a `Dual` and a `u0` which is a `Float64`, after the first time step, `f(u,p,t) = p*u` -will change `u0` from `Float64` to `Dual`. Thus the state variable always needs to be converted -to a dual number before the solve. Worse still, this needs to be done in the case of -`f(du,u,p,t) = du[1] = p*u[1]`, and thus running `f` and taking the return value is not a valid -way to calculate the required state type. - -But given the properties of automatic differentiation requiring that differentiation of parameters -implies differentiation of state, we assume any dual parameters implies differentiation of state -and then attempt to upconvert `u0` to match that dual-ness. Because this changes types, this needs -to be specified at compiled time and thus cannot have a Bool-based opt out, so in the future this -may be extended to use a preference system to opt-out with a `UPCONVERT_DUALS`. In the case where -upconversion is not done automatically, the user is required to upconvert all initial conditions -themselves, for an example of how this can be confusing to a user see -https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a-forced-differential-equation/82937 -""" -@generated function anyeltypedual(x, ::Type{Val{counter}}) where {counter} - x = x.name === Core.Compiler.typename(Type) ? x.parameters[1] : x - if isdualtype(x) - :($x) - elseif fieldnames(x) === () - :(Any) - elseif counter < DUALCHECK_RECURSION_MAX - T = diffeqmapreduce(x -> anyeltypedual(x, Val{counter + 1}), promote_dual, - x.parameters) - if T === Any || isconcretetype(T) - :($T) - else - :(diffeqmapreduce(DualEltypeChecker($x, $counter + 1), promote_dual, - map(Val, fieldnames($(typeof(x)))))) - end - else - :(Any) - end -end - -const FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE = """ - Failed to automatically detect ForwardDiff compatability of - the parameter object. In order for ForwardDiff.jl automatic - differentiation to work on a solution object, the state of - the differential equation or nonlinear solve (`u0`) needs to - be converted to a Dual type which matches the values being - differentiated. For example, for a loss function loss(p) - where `p`` is a `Vector{Float64}`, this conversion is - equivalent to: - - ```julia - # Convert u0 to match the new Dual element type of `p` - _prob = remake(prob, u0 = eltype(p).(prob.u0)) - ``` - - In most cases, SciML tools are able to do this conversion - automatically. However, it seems you have provided a - parameter type for which this automatic conversion has failed. - - To fix this, you can do the conversion yourself. For example, - if you have a parameter vector being optimized `p` which is - then put into an odd struct, you can manually convert `u0` - to match `p`: - - ```julia - function loss(p) - _prob = remake(prob, u0 = eltype(p).(prob.u0), p = MyStruct(p)) - sol = solve(_prob, ...) - # do stuff on sol - end - ``` - - Or you can define a dispatch on `DiffEqBase.anyeltypedual` - which tells the system what fields to interpret as the - differentiable parts. For example, to support ODESolutions - as parameters we tell it the data is `sol.u` and `sol.t` via: - - ```julia - function DiffEqBase.anyeltypedual(sol::ODESolution, counter = 0) - DiffEqBase.anyeltypedual((sol.u, sol.t)) - end - ``` - - To opt a type out of the dual checking, define an overload - that returns Any. For example: - - ```julia - function DiffEqBase.anyeltypedual(::YourType, ::Type{Val{counter}}) where {counter} - Any - end - ``` - - If you have defined this on a common type which should - be more generally supported, please open a pull request - adding this dispatch. If you need help defining this dispatch, - feel free to open an issue. - """ - -struct ForwardDiffAutomaticDetectionFailure <: Exception end - -function Base.showerror(io::IO, e::ForwardDiffAutomaticDetectionFailure) - print(io, FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE) -end - -function anyeltypedual(::Type{Union{}}) - throw(ForwardDiffAutomaticDetectionFailure()) -end - -function anyeltypedual(::Type{<:AbstractTimeseriesSolution{T, N}}, - ::Type{Val{counter}} = Val{0}) where {T, N, counter} - anyeltypedual(T) -end - -function anyeltypedual( - ::Type{T}, - ::Type{Val{counter}} = Val{0}) where {counter} where {T <: - NonlinearProblem{ - uType, iip, pType}} where {uType, iip, pType} - return anyeltypedual((uType, pType), Val{counter}) -end - -function anyeltypedual( - ::Type{T}, - ::Type{Val{counter}} = Val{0}) where {counter} where {T <: - NonlinearLeastSquaresProblem{ - uType, iip, pType}} where {uType, iip, pType} - return anyeltypedual((uType, pType), Val{counter}) -end - -function anyeltypedual(x::SciMLBase.RecipesBase.AbstractPlot, - ::Type{Val{counter}} = Val{0}) where {counter} - Any -end -function anyeltypedual(x::Returns, ::Type{Val{counter}} = Val{0}) where {counter} - anyeltypedual(x.value, Val{counter}) -end - -Base.@assume_effects :foldable function __anyeltypedual(::Type{T}) where {T} - if T isa Union - promote_dual(anyeltypedual(T.a), anyeltypedual(T.b)) - elseif hasproperty(T, :parameters) - mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any) - else - T - end -end -function anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T} - __anyeltypedual(T) -end - -function anyeltypedual(::Type{T}, - ::Type{Val{counter}} = Val{0}) where {counter} where {T <: - Union{AbstractArray, Set}} - anyeltypedual(eltype(T)) -end -Base.@pure function __anyeltypedual_ntuple(::Type{T}) where {T <: NTuple} - if isconcretetype(eltype(T)) - return eltype(T) - end - if isempty(T.parameters) - Any - else - mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any) - end -end -function anyeltypedual( - ::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: NTuple} - __anyeltypedual_ntuple(T) -end - -# Any in this context just means not Dual -function anyeltypedual( - x::SciMLBase.NullParameters, ::Type{Val{counter}} = Val{0}) where {counter} - Any -end - -function anyeltypedual(sol::RecursiveArrayTools.AbstractDiffEqArray, counter = 0) - diffeqmapreduce(anyeltypedual, promote_dual, (sol.u, sol.t)) -end - -function anyeltypedual(prob::Union{ODEProblem, SDEProblem, RODEProblem, DDEProblem}, - ::Type{Val{counter}} = Val{0}) where {counter} - anyeltypedual((prob.u0, prob.p, prob.tspan)) -end - -function anyeltypedual( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem, OptimizationProblem}, - ::Type{Val{counter}} = Val{0}) where {counter} - anyeltypedual((prob.u0, prob.p)) -end - -function anyeltypedual(x::Number, ::Type{Val{counter}} = Val{0}) where {counter} - anyeltypedual(typeof(x)) -end -function anyeltypedual( - x::Union{String, Symbol}, ::Type{Val{counter}} = Val{0}) where {counter} - typeof(x) -end -function anyeltypedual(x::Union{AbstractArray{T}, Set{T}}, - ::Type{Val{counter}} = Val{0}) where {counter} where { - T <: - Union{Number, - Symbol, - String}} - anyeltypedual(T) -end -function anyeltypedual(x::Union{AbstractArray{T}, Set{T}}, - ::Type{Val{counter}} = Val{0}) where {counter} where { - T <: Union{ - AbstractArray{ - <:Number, - }, - Set{ - <:Number, - }}} - anyeltypedual(eltype(x)) -end -function anyeltypedual(x::Union{AbstractArray{T}, Set{T}}, - ::Type{Val{counter}} = Val{0}) where {counter} where {N, T <: NTuple{N, <:Number}} - anyeltypedual(eltype(x)) -end - -# Try to avoid this dispatch because it can lead to type inference issues when !isconcrete(eltype(x)) -function anyeltypedual(x::AbstractArray, ::Type{Val{counter}} = Val{0}) where {counter} - if isconcretetype(eltype(x)) - anyeltypedual(eltype(x)) - elseif !isempty(x) && all(i -> isassigned(x, i), 1:length(x)) && - counter < DUALCHECK_RECURSION_MAX - _counter = Val{counter + 1} - mapreduce(y -> anyeltypedual(y, _counter), promote_dual, x) - else - # This fallback to Any is required since otherwise we cannot handle `undef` in all cases - # misses cases of - Any - end -end - -function anyeltypedual(x::Set, ::Type{Val{counter}} = Val{0}) where {counter} - if isconcretetype(eltype(x)) - anyeltypedual(eltype(x)) - else - # This fallback to Any is required since otherwise we cannot handle `undef` in all cases - Any - end -end - -function anyeltypedual(x::Tuple, ::Type{Val{counter}} = Val{0}) where {counter} - # Handle the empty tuple case separately for inference and to avoid mapreduce error - if x === () - Any - else - diffeqmapreduce(anyeltypedual, promote_dual, x) - end -end -function anyeltypedual(x::AbstractDict, ::Type{Val{counter}} = Val{0}) where {counter} - isempty(x) ? eltype(values(x)) : mapreduce(anyeltypedual, promote_dual, values(x)) -end -function anyeltypedual(x::NamedTuple, ::Type{Val{counter}} = Val{0}) where {counter} - anyeltypedual(values(x)) -end - -function anyeltypedual( - f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}}) where {counter} - Any -end - -anyeltypedual(::@Kwargs{}, ::Type{Val{counter}} = Val{0}) where {counter} = Any -anyeltypedual(::Type{@Kwargs{}}, ::Type{Val{counter}} = Val{0}) where {counter} = Any - -# Opt out since these are using for preallocation, not differentiation -function anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module}, - ::Type{Val{counter}} = Val{0}) where {counter} - Any -end -function anyeltypedual(x::Type{T}, - ::Type{Val{counter}} = Val{0}) where {counter} where {T <: - ForwardDiff.AbstractConfig} - Any -end - -function anyeltypedual(x::ForwardDiff.DiffResults.DiffResult, - ::Type{Val{counter}} = Val{0}) where {counter} - Any -end -function anyeltypedual(x::Type{T}, - ::Type{Val{counter}} = Val{0}) where {counter} where {T <: - ForwardDiff.DiffResults.DiffResult} - Any -end - -function anyeltypedual(::Type{T}, - ::Type{Val{counter}} = Val{0}) where {counter} where {T <: ForwardDiff.Dual} - T -end - -function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p, tspan, prob, kwargs) - if (haskey(kwargs, :callback) && has_continuous_callback(kwargs[:callback])) || - (haskey(prob.kwargs, :callback) && has_continuous_callback(prob.kwargs[:callback])) - return _promote_tspan(eltype(u0).(tspan), kwargs) +# Copy of the other prob2dtmin dispatch, just for optionality +function prob2dtmin(tspan, ::ForwardDiff.Dual, use_end_time) + t1, t2 = tspan + isfinite(t1) || throw(ArgumentError("t0 in the tspan `(t0, t1)` must be finite")) + if use_end_time && isfinite(t2 - t1) + return max(eps(t2), eps(t1)) else - return _promote_tspan(tspan, kwargs) + return max(eps(typeof(t1)), eps(t1)) end end -function promote_tspan(u0::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, p, tspan, prob, - kwargs) - return _promote_tspan(real(eltype(u0)).(tspan), kwargs) -end - -function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p, - tspan::Tuple{<:ForwardDiff.Dual, <:ForwardDiff.Dual}, prob, kwargs) - return _promote_tspan(tspan, kwargs) -end - -value(x::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = V -value(x::ForwardDiff.Dual) = value(ForwardDiff.value(x)) - -unitfulvalue(x::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = V -unitfulvalue(x::ForwardDiff.Dual) = unitfulvalue(ForwardDiff.value(x)) - -sse(x::ForwardDiff.Dual) = sse(ForwardDiff.value(x)) + sum(sse, ForwardDiff.partials(x)) -function DiffEqBase.totallength(x::ForwardDiff.Dual) - return DiffEqBase.totallength(ForwardDiff.value(x)) + - sum(DiffEqBase.totallength, ForwardDiff.partials(x)) +function hasdualpromote(u0, t::Number) + hasmethod(ArrayInterface.promote_eltype, + Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) && + hasmethod(promote_rule, + Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) && + hasmethod(promote_rule, + Tuple{Type{eltype(u0)}, Type{typeof(t)}}) end @inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual, ::Any) = sqrt(sse(u)) diff --git a/ext/DiffEqBaseGTPSAExt.jl b/ext/DiffEqBaseGTPSAExt.jl index 10999bb26..130efde3a 100644 --- a/ext/DiffEqBaseGTPSAExt.jl +++ b/ext/DiffEqBaseGTPSAExt.jl @@ -1,7 +1,8 @@ module DiffEqBaseGTPSAExt using DiffEqBase -import DiffEqBase: value, ODE_DEFAULT_NORM +import DiffEqBase: ODE_DEFAULT_NORM +import SciMLBase: value, unitfulvalue using GTPSA value(x::TPS) = scalar(x) diff --git a/ext/DiffEqBaseMeasurementsExt.jl b/ext/DiffEqBaseMeasurementsExt.jl index a72423708..584e5242b 100644 --- a/ext/DiffEqBaseMeasurementsExt.jl +++ b/ext/DiffEqBaseMeasurementsExt.jl @@ -4,18 +4,6 @@ using DiffEqBase import DiffEqBase: value using Measurements -function DiffEqBase.promote_u0(u0::AbstractArray{<:Measurements.Measurement}, - p::AbstractArray{<:Measurements.Measurement}, t0) - u0 -end -DiffEqBase.promote_u0(u0, p::AbstractArray{<:Measurements.Measurement}, t0) = eltype(p).(u0) - -value(x::Type{Measurements.Measurement{T}}) where {T} = T -value(x::Measurements.Measurement) = Measurements.value(x) - -unitfulvalue(x::Type{Measurements.Measurement{T}}) where {T} = T -unitfulvalue(x::Measurements.Measurement) = Measurements.value(x) - # Support adaptive steps should be errorless @inline function DiffEqBase.ODE_DEFAULT_NORM( u::AbstractArray{ diff --git a/ext/DiffEqBaseMonteCarloMeasurementsExt.jl b/ext/DiffEqBaseMonteCarloMeasurementsExt.jl index d2b33d1a6..d3335a491 100644 --- a/ext/DiffEqBaseMonteCarloMeasurementsExt.jl +++ b/ext/DiffEqBaseMonteCarloMeasurementsExt.jl @@ -4,34 +4,6 @@ using DiffEqBase import DiffEqBase: value using MonteCarloMeasurements -function DiffEqBase.promote_u0( - u0::AbstractArray{ - <:MonteCarloMeasurements.AbstractParticles, - }, - p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, - t0) - u0 -end -function DiffEqBase.promote_u0(u0, - p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, - t0) - eltype(p).(u0) -end - -function DiffEqBase.promote_u0(::Nothing, - p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, - t0) - return nothing -end - -DiffEqBase.value(x::Type{MonteCarloMeasurements.AbstractParticles{T, N}}) where {T, N} = T -DiffEqBase.value(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles) -function DiffEqBase.unitfulvalue(x::Type{MonteCarloMeasurements.AbstractParticles{ - T, N}}) where {T, N} - T -end -DiffEqBase.unitfulvalue(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles) - # Support adaptive steps should be errorless @inline function DiffEqBase.ODE_DEFAULT_NORM( u::AbstractArray{ diff --git a/ext/DiffEqBaseReverseDiffExt.jl b/ext/DiffEqBaseReverseDiffExt.jl index 879989d30..a69c5715d 100644 --- a/ext/DiffEqBaseReverseDiffExt.jl +++ b/ext/DiffEqBaseReverseDiffExt.jl @@ -5,57 +5,6 @@ import DiffEqBase: value import ReverseDiff import DiffEqBase.ArrayInterface -function DiffEqBase.anyeltypedual(::Type{T}, - ::Type{Val{counter}} = Val{0}) where {counter} where { - V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}} - DiffEqBase.anyeltypedual(V, Val{counter}) -end - -DiffEqBase.value(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V -function DiffEqBase.value(x::Type{ - ReverseDiff.TrackedArray{V, D, N, VA, DA}, -}) where {V, D, - N, VA, - DA} - Array{V, N} -end -DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value -DiffEqBase.value(x::ReverseDiff.TrackedArray) = x.value - -DiffEqBase.unitfulvalue(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V -function DiffEqBase.unitfulvalue(x::Type{ - ReverseDiff.TrackedArray{V, D, N, VA, DA}, -}) where {V, D, - N, VA, - DA} - Array{V, N} -end -DiffEqBase.unitfulvalue(x::ReverseDiff.TrackedReal) = x.value -DiffEqBase.unitfulvalue(x::ReverseDiff.TrackedArray) = x.value - -# Force TrackedArray from TrackedReal when reshaping W\b -DiffEqBase._reshape(v::AbstractVector{<:ReverseDiff.TrackedReal}, siz) = reduce(vcat, v) - -DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0 -function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, - p::ReverseDiff.TrackedArray, t0) - u0 -end -function DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, - p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) - u0 -end -function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, - p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) - u0 -end -DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0) -function DiffEqBase.promote_u0( - u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ReverseDiff.ForwardDiff.Dual} - ReverseDiff.track(T.(u0)) -end -DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0) - # Support adaptive with non-tracked time @inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t) sqrt(sum(abs2, DiffEqBase.value(u)) / length(u)) diff --git a/ext/DiffEqBaseTrackerExt.jl b/ext/DiffEqBaseTrackerExt.jl index 72c869fb8..6342cb86b 100644 --- a/ext/DiffEqBaseTrackerExt.jl +++ b/ext/DiffEqBaseTrackerExt.jl @@ -4,36 +4,6 @@ using DiffEqBase import DiffEqBase: value import Tracker -DiffEqBase.value(x::Type{Tracker.TrackedReal{T}}) where {T} = T -DiffEqBase.value(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} = Array{T, N} -DiffEqBase.value(x::Tracker.TrackedReal) = x.data -DiffEqBase.value(x::Tracker.TrackedArray) = x.data - -DiffEqBase.unitfulvalue(x::Type{Tracker.TrackedReal{T}}) where {T} = T -function DiffEqBase.unitfulvalue(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} - Array{T, N} -end -DiffEqBase.unitfulvalue(x::Tracker.TrackedReal) = x.data -DiffEqBase.unitfulvalue(x::Tracker.TrackedArray) = x.data - -DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::Tracker.TrackedArray, t0) = u0 -function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, - p::Tracker.TrackedArray, t0) - u0 -end -function DiffEqBase.promote_u0(u0::Tracker.TrackedArray, - p::AbstractArray{<:Tracker.TrackedReal}, t0) - u0 -end -function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, - p::AbstractArray{<:Tracker.TrackedReal}, t0) - u0 -end -DiffEqBase.promote_u0(u0, p::Tracker.TrackedArray, t0) = Tracker.track(u0) -DiffEqBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype(p).(u0) - -@inline Base.any(f::Function, x::Tracker.TrackedArray) = any(f, Tracker.data(x)) - # Support adaptive with non-tracked time @inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, t) sqrt(sum(abs2, DiffEqBase.value(u)) / length(u)) diff --git a/ext/DiffEqBaseUnitfulExt.jl b/ext/DiffEqBaseUnitfulExt.jl index 381b0e516..85dd93f72 100644 --- a/ext/DiffEqBaseUnitfulExt.jl +++ b/ext/DiffEqBaseUnitfulExt.jl @@ -1,7 +1,7 @@ module DiffEqBaseUnitfulExt using DiffEqBase -import DiffEqBase: value +import SciMLBase: unitfulvalue, value using Unitful # Support adaptive errors should be errorless for exponentiation From 25f3a597c363d1121dbf83a20037fa9b24647592 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 21 Aug 2025 22:47:30 -0400 Subject: [PATCH 05/17] bump compat for SciMLBase --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5997760af..d00bd380f 100644 --- a/Project.toml +++ b/Project.toml @@ -94,7 +94,7 @@ Printf = "1.9" RecursiveArrayTools = "3.1" Reexport = "1.0" ReverseDiff = "1" -SciMLBase = "2.94.0" +SciMLBase = "2.112.0" SciMLOperators = "1" SciMLStructures = "1.5" Setfield = "1" From 809d3dbe9fdf87be7a8ddc860b1e1bbda51c3278 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 21 Aug 2025 23:18:50 -0400 Subject: [PATCH 06/17] import errors from SciMLBase --- src/DiffEqBase.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index 618ea35cf..a776b8030 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -98,7 +98,12 @@ import SciMLBase: solve, init, step!, solve!, __init, __solve, update_coefficien extract_alg, checkkwargs, has_kwargs, eltypedual, get_updated_symbolic_problem, get_concrete_p, get_concrete_u0, promote_u0, isconcreteu0, isconcretedu0, get_concrete_du0, _reshape, value, unitfulvalue, anyeltypedual, allowedkeywords, - sse, totallength, __sum, DualEltypeChecker + sse, totallength, __sum, DualEltypeChecker, KeywordArgError, KeywordArgWarn, KeywordArgSilent, KWARGWARN_MESSAGE, KWARGERROR_MESSAGE, + CommonKwargError, IncompatibleInitialConditionError, NO_DEFAULT_ALGORITHM_MESSAGE, NoDefaultAlgorithmError, NO_TSPAN_MESSAGE, NoTspanError, + NAN_TSPAN_MESSAGE, NaNTspanError, NON_SOLVER_MESSAGE, NonSolverError, NOISE_SIZE_MESSAGE, NoiseSizeIncompatabilityError, PROBSOLVER_PAIRING_MESSAGE, + ProblemSolverPairingError, compatible_problem_types, DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE, DirectAutodiffError, NONNUMBER_ELTYPE_MESSAGE, NonNumberEltypeError, + GENERIC_NUMBER_TYPE_ERROR_MESSAGE, GenericNumberTypeError, COMPLEX_SUPPORT_ERROR_MESSAGE, ComplexSupportError, COMPLEX_TSPAN_ERROR_MESSAGE, ComplexTspanError, + TUPLE_STATE_ERROR_MESSAGE, TupleStateError, MASS_MATRIX_ERROR_MESSAGE, IncompatibleMassMatrixError, LATE_BINDING_TSTOPS_ERROR_MESSAGE, LateBindingTstopsNotSupportedError import SciMLStructures From 0d62fb69ec717fbf0f4501e1447a782ff65aed10 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 21 Aug 2025 23:41:56 -0400 Subject: [PATCH 07/17] import NonConcreteEltypeError --- src/DiffEqBase.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index a776b8030..cc8f125b6 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -103,7 +103,8 @@ import SciMLBase: solve, init, step!, solve!, __init, __solve, update_coefficien NAN_TSPAN_MESSAGE, NaNTspanError, NON_SOLVER_MESSAGE, NonSolverError, NOISE_SIZE_MESSAGE, NoiseSizeIncompatabilityError, PROBSOLVER_PAIRING_MESSAGE, ProblemSolverPairingError, compatible_problem_types, DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE, DirectAutodiffError, NONNUMBER_ELTYPE_MESSAGE, NonNumberEltypeError, GENERIC_NUMBER_TYPE_ERROR_MESSAGE, GenericNumberTypeError, COMPLEX_SUPPORT_ERROR_MESSAGE, ComplexSupportError, COMPLEX_TSPAN_ERROR_MESSAGE, ComplexTspanError, - TUPLE_STATE_ERROR_MESSAGE, TupleStateError, MASS_MATRIX_ERROR_MESSAGE, IncompatibleMassMatrixError, LATE_BINDING_TSTOPS_ERROR_MESSAGE, LateBindingTstopsNotSupportedError + TUPLE_STATE_ERROR_MESSAGE, TupleStateError, MASS_MATRIX_ERROR_MESSAGE, IncompatibleMassMatrixError, LATE_BINDING_TSTOPS_ERROR_MESSAGE, LateBindingTstopsNotSupportedError, + NONCONCRETE_ELTYPE_MESSAGE, NonConcreteEltypeError import SciMLStructures From 6cb37fbbd63f8cac9f3823d79ea3706f094aef44 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 22 Aug 2025 15:09:15 -0400 Subject: [PATCH 08/17] get rid of comments --- src/solve.jl | 657 --------------------------------------------------- 1 file changed, 657 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 092b06f59..a3143c4a7 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -7,536 +7,6 @@ NO_TSPAN_PROBS = Union{AbstractLinearProblem, AbstractNonlinearProblem, AbstractIntegralProblem, AbstractSteadyStateProblem, AbstractJumpProblem} -# has_kwargs(_prob::AbstractDEProblem) = has_kwargs(typeof(_prob)) -# Base.@pure __has_kwargs(::Type{T}) where {T} = :kwargs ∈ fieldnames(T) -# has_kwargs(::Type{T}) where {T} = __has_kwargs(T) - -# const allowedkeywords = (:dense, -# :saveat, -# :save_idxs, -# :tstops, -# :tspan, -# :d_discontinuities, -# :save_everystep, -# :save_on, -# :save_start, -# :save_end, -# :initialize_save, -# :adaptive, -# :abstol, -# :reltol, -# :dt, -# :dtmax, -# :dtmin, -# :force_dtmin, -# :internalnorm, -# :controller, -# :gamma, -# :beta1, -# :beta2, -# :qmax, -# :qmin, -# :qsteady_min, -# :qsteady_max, -# :qoldinit, -# :failfactor, -# :calck, -# :alias_u0, -# :maxiters, -# :maxtime, -# :callback, -# :isoutofdomain, -# :unstable_check, -# :verbose, -# :merge_callbacks, -# :progress, -# :progress_steps, -# :progress_name, -# :progress_message, -# :progress_id, -# :timeseries_errors, -# :dense_errors, -# :weak_timeseries_errors, -# :weak_dense_errors, -# :wrap, -# :calculate_error, -# :initializealg, -# :alg, -# :save_noise, -# :delta, -# :seed, -# :alg_hints, -# :kwargshandle, -# :trajectories, -# :batch_size, -# :sensealg, -# :advance_to_tstop, -# :stop_at_next_tstop, -# :u0, -# :p, -# # These two are from the default algorithm handling -# :default_set, -# :second_time, -# # This is for DiffEqDevTools -# :prob_choice, -# # Jump problems -# :alias_jump, -# # This is for copying/deepcopying noise in StochasticDiffEq -# :alias_noise, -# # This is for SimpleNonlinearSolve handling for batched Nonlinear Solves -# :batch, -# # Shooting method in BVP needs to differentiate between these two categories -# :nlsolve_kwargs, -# :odesolve_kwargs, -# # If Solvers which internally use linsolve -# :linsolve_kwargs, -# # Solvers internally using EnsembleProblem -# :ensemblealg, -# # Fine Grained Control of Tracing (Storing and Logging) during Solve -# :show_trace, -# :trace_level, -# :store_trace, -# # Termination condition for solvers -# :termination_condition, -# # For AbstractAliasSpecifier -# :alias, -# # Parameter estimation with BVP -# :fit_parameters) - -# const KWARGWARN_MESSAGE = """ -# Unrecognized keyword arguments found. -# The only allowed keyword arguments to `solve` are: -# $allowedkeywords - -# See https://diffeq.sciml.ai/stable/basics/common_solver_opts/ for more details. - -# Set kwargshandle=KeywordArgError for an error message. -# Set kwargshandle=KeywordArgSilent to ignore this message. -# """ - -# const KWARGERROR_MESSAGE = """ -# Unrecognized keyword arguments found. -# The only allowed keyword arguments to `solve` are: -# $allowedkeywords - -# See https://diffeq.sciml.ai/stable/basics/common_solver_opts/ for more details. -# """ - -# struct CommonKwargError <: Exception -# kwargs::Any -# end - -# function Base.showerror(io::IO, e::CommonKwargError) -# println(io, KWARGERROR_MESSAGE) -# notin = collect(map(x -> x ∉ allowedkeywords, keys(e.kwargs))) -# unrecognized = collect(keys(e.kwargs))[notin] -# print(io, "Unrecognized keyword arguments: ") -# printstyled(io, unrecognized; bold = true, color = :red) -# print(io, "\n\n") -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# @enum KeywordArgError KeywordArgWarn KeywordArgSilent - -# const INCOMPATIBLE_U0_MESSAGE = """ -# Initial condition incompatible with functional form. -# Detected an in-place function with an initial condition of type Number or SArray. -# This is incompatible because Numbers cannot be mutated, i.e. -# `x = 2.0; y = 2.0; x .= y` will error. - -# If using a immutable initial condition type, please use the out-of-place form. -# I.e. define the function `du=f(u,p,t)` instead of attempting to "mutate" the immutable `du`. - -# If your differential equation function was defined with multiple dispatches and one is -# in-place, then the automatic detection will choose in-place. In this case, override the -# choice in the problem constructor, i.e. `ODEProblem{false}(f,u0,tspan,p,kwargs...)`. - -# For a longer discussion on mutability vs immutability and in-place vs out-of-place, see: -# https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Example-Accelerating-a-Non-Stiff-Equation:-The-Lorenz-Equation -# """ - -# struct IncompatibleInitialConditionError <: Exception end - -# function Base.showerror(io::IO, e::IncompatibleInitialConditionError) -# print(io, INCOMPATIBLE_U0_MESSAGE) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const NO_DEFAULT_ALGORITHM_MESSAGE = """ -# Default algorithm choices require DifferentialEquations.jl. -# Please specify an algorithm (e.g., `solve(prob, Tsit5())` or -# `init(prob, Tsit5())` for an ODE) or import DifferentialEquations -# directly. - -# You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ -# and its associated pages. -# """ - -# struct NoDefaultAlgorithmError <: Exception end - -# function Base.showerror(io::IO, e::NoDefaultAlgorithmError) -# print(io, NO_DEFAULT_ALGORITHM_MESSAGE) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const NO_TSPAN_MESSAGE = """ -# No tspan is set in the problem or chosen in the init/solve call -# """ - -# struct NoTspanError <: Exception end - -# function Base.showerror(io::IO, e::NoTspanError) -# print(io, NO_TSPAN_MESSAGE) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const NAN_TSPAN_MESSAGE = """ -# NaN tspan is set in the problem or chosen in the init/solve call. -# Note that -Inf and Inf values are allowed in the timespan for solves -# which are terminated via callbacks, however NaN values are not allowed -# since the direction of time is undetermined. -# """ - -# struct NaNTspanError <: Exception end - -# function Base.showerror(io::IO, e::NaNTspanError) -# print(io, NAN_TSPAN_MESSAGE) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const NON_SOLVER_MESSAGE = """ -# The arguments to solve are incorrect. -# The second argument must be a solver choice, `solve(prob,alg)` -# where `alg` is a `<: AbstractDEAlgorithm`, e.g. `Tsit5()`. - -# Please double check the arguments being sent to the solver. - -# You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ -# and its associated pages. -# """ - -# struct NonSolverError <: Exception end - -# function Base.showerror(io::IO, e::NonSolverError) -# print(io, NON_SOLVER_MESSAGE) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const NOISE_SIZE_MESSAGE = """ -# Noise sizes are incompatible. The expected number of noise terms in the defined -# `noise_rate_prototype` does not match the number of noise terms in the defined -# `AbstractNoiseProcess`. Please ensure that -# size(prob.noise_rate_prototype,2) == length(prob.noise.W[1]). - -# Note: Noise process definitions require that users specify `u0`, and this value is -# directly used in the definition. For example, if `noise = WienerProcess(0.0,0.0)`, -# then the noise process is a scalar with `u0=0.0`. If `noise = WienerProcess(0.0,[0.0])`, -# then the noise process is a vector with `u0=0.0`. If `noise_rate_prototype = zeros(2,4)`, -# then the noise process must be a 4-dimensional process, for example -# `noise = WienerProcess(0.0,zeros(4))`. This error is a sign that the user definition -# of `noise_rate_prototype` and `noise` are not aligned in this manner and the definitions should -# be double checked. -# """ - -# struct NoiseSizeIncompatabilityError <: Exception -# prototypesize::Int -# noisesize::Int -# end - -# function Base.showerror(io::IO, e::NoiseSizeIncompatabilityError) -# println(io, NOISE_SIZE_MESSAGE) -# println(io, "size(prob.noise_rate_prototype,2) = $(e.prototypesize)") -# println(io, "length(prob.noise.W[1]) = $(e.noisesize)") -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const PROBSOLVER_PAIRING_MESSAGE = """ -# Incompatible problem+solver pairing. -# For example, this can occur if an ODE solver is passed with an SDEProblem. -# Solvers are only capable of handling specific problem types. Please double -# check that the chosen pairing is capable for handling the given problems. -# """ - -# struct ProblemSolverPairingError <: Exception -# prob::Any -# alg::Any -# end - -# function Base.showerror(io::IO, e::ProblemSolverPairingError) -# println(io, PROBSOLVER_PAIRING_MESSAGE) -# println(io, "Problem type: $(SciMLBase.__parameterless_type(typeof(e.prob)))") -# println(io, "Solver type: $(SciMLBase.__parameterless_type(typeof(e.alg)))") -# println(io, -# "Problem types compatible with the chosen solver: $(compatible_problem_types(e.prob,e.alg))") -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# function compatible_problem_types(prob, alg) -# if alg isa AbstractODEAlgorithm -# ODEProblem -# elseif alg isa AbstractSDEAlgorithm -# (SDEProblem, SDDEProblem) -# elseif alg isa AbstractDDEAlgorithm # StochasticDelayDiffEq.jl just uses the SDE alg -# DDEProblem -# elseif alg isa AbstractDAEAlgorithm -# DAEProblem -# elseif alg isa AbstractSteadyStateAlgorithm -# SteadyStateProblem -# end -# end - -# const DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE = """ -# Incompatible solver + automatic differentiation pairing. -# The chosen automatic differentiation algorithm requires the ability -# for compiler transforms on the code which is only possible on pure-Julia -# solvers such as those from OrdinaryDiffEq.jl. Direct differentiation methods -# which require this ability include: - -# - Direct use of ForwardDiff.jl on the solver -# - `ForwardDiffSensitivity`, `ReverseDiffAdjoint`, `TrackerAdjoint`, and `ZygoteAdjoint` -# sensealg choices for adjoint differentiation. - -# Either switch the choice of solver to a pure Julia method, or change the automatic -# differentiation method to one that does not require such transformations. - -# For more details on automatic differentiation, adjoint, and sensitivity analysis -# of differential equations, see the documentation page: - -# https://diffeq.sciml.ai/stable/analysis/sensitivity/ -# """ - -# struct DirectAutodiffError <: Exception end - -# function Base.showerror(io::IO, e::DirectAutodiffError) -# println(io, DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const NONCONCRETE_ELTYPE_MESSAGE = """ -# Non-concrete element type inside of an `Array` detected. -# Arrays with non-concrete element types, such as -# `Array{Union{Float32,Float64}}`, are not supported by the -# differential equation solvers. Anyways, this is bad for -# performance so you don't want to be doing this! - -# If this was a mistake, promote the element types to be -# all the same. If this was intentional, for example, -# using Unitful.jl with different unit values, then use -# an array type which has fast broadcast support for -# heterogeneous values such as the ArrayPartition -# from RecursiveArrayTools.jl. For example: - -# ```julia -# using RecursiveArrayTools -# x = ArrayPartition([1.0,2.0],[1f0,2f0]) -# y = ArrayPartition([3.0,4.0],[3f0,4f0]) -# x .+ y # fast, stable, and usable as u0 into DiffEq! -# ``` - -# Element type: -# """ - -# struct NonConcreteEltypeError <: Exception -# eltype::Any -# end - -# function Base.showerror(io::IO, e::NonConcreteEltypeError) -# print(io, NONCONCRETE_ELTYPE_MESSAGE) -# print(io, e.eltype) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const NONNUMBER_ELTYPE_MESSAGE = """ -# Non-Number element type inside of an `Array` detected. -# Arrays with non-number element types, such as -# `Array{Array{Float64}}`, are not supported by the -# solvers. - -# If you are trying to use an array of arrays structure, -# look at the tools in RecursiveArrayTools.jl. For example: - -# If this was a mistake, promote the element types to be -# all the same. If this was intentional, for example, -# using Unitful.jl with different unit values, then use -# an array type which has fast broadcast support for -# heterogeneous values such as the ArrayPartition -# from RecursiveArrayTools.jl. For example: - -# ```julia -# using RecursiveArrayTools -# u0 = ArrayPartition([1.0,2.0],[3.0,4.0]) -# u0 = VectorOfArray([1.0,2.0],[3.0,4.0]) -# ``` - -# are both initial conditions which would be compatible with -# the solvers. Or use ComponentArrays.jl for more complex -# nested structures. - -# Element type: -# """ - -# struct NonNumberEltypeError <: Exception -# eltype::Any -# end - -# function Base.showerror(io::IO, e::NonNumberEltypeError) -# print(io, NONNUMBER_ELTYPE_MESSAGE) -# print(io, e.eltype) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const GENERIC_NUMBER_TYPE_ERROR_MESSAGE = """ -# Non-standard number type (i.e. not Float32, Float64, -# ComplexF32, or ComplexF64) detected as the element type -# for the initial condition or time span. These generic -# number types are only compatible with the pure Julia -# solvers which support generic programming, such as -# OrdinaryDiffEq.jl. The chosen solver does not support -# this functionality. Please double check that the initial -# condition and time span types are correct, and check that -# the chosen solver was correct. -# """ - -# struct GenericNumberTypeError <: Exception -# alg::Any -# uType::Any -# tType::Any -# end - -# function Base.showerror(io::IO, e::GenericNumberTypeError) -# println(io, GENERIC_NUMBER_TYPE_ERROR_MESSAGE) -# println(io, "Solver: $(e.alg)") -# println(io, "u0 type: $(e.uType)") -# print(io, "Timespan type: $(e.tType)") -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const COMPLEX_SUPPORT_ERROR_MESSAGE = """ -# Complex number type (i.e. ComplexF32, or ComplexF64) -# detected as the element type for the initial condition -# with an algorithm that does not support complex numbers. -# Please check that the initial condition type is correct. -# If complex number support is needed, try different solvers -# such as those from OrdinaryDiffEq.jl. -# """ - -# struct ComplexSupportError <: Exception -# alg::Any -# end - -# function Base.showerror(io::IO, e::ComplexSupportError) -# println(io, COMPLEX_SUPPORT_ERROR_MESSAGE) -# println(io, "Solver: $(e.alg)") -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const COMPLEX_TSPAN_ERROR_MESSAGE = """ -# Complex number type (i.e. ComplexF32, or ComplexF64) -# detected as the element type for the independent variable -# (i.e. time span). Please check that the tspan type is correct. -# No solvers support complex time spans. If this is required, -# please open an issue. -# """ - -# struct ComplexTspanError <: Exception end - -# function Base.showerror(io::IO, e::ComplexTspanError) -# println(io, COMPLEX_TSPAN_ERROR_MESSAGE) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const TUPLE_STATE_ERROR_MESSAGE = """ -# Tuple type used as a state. Since a tuple does not have vector -# properties, it will not work as a state type in equation solvers. -# Instead, change your equation from using tuple constructors `()` -# to static array constructors `SA[]`. For example, change: - -# ```julia -# function ftup((a,b),p,t) -# return b,-a -# end -# u0 = (1.0,2.0) -# tspan = (0.0,1.0) -# ODEProblem(ftup,u0,tspan) -# ``` - -# to: - -# ```julia -# using StaticArrays -# function fsa(u,p,t) -# SA[u[2],u[1]] -# end -# u0 = SA[1.0,2.0] -# tspan = (0.0,1.0) -# ODEProblem(ftup,u0,tspan) -# ``` - -# This will be safer and fast for small ODEs. For more information, see: -# https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Further-Optimizations-of-Small-Non-Stiff-ODEs-with-StaticArrays -# """ - -# struct TupleStateError <: Exception end - -# function Base.showerror(io::IO, e::TupleStateError) -# println(io, TUPLE_STATE_ERROR_MESSAGE) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const MASS_MATRIX_ERROR_MESSAGE = """ -# Mass matrix size is incompatible with initial condition -# sizing. The mass matrix must represent the `vec` -# form of the initial condition `u0`, i.e. -# `size(mm,1) == size(mm,2) == length(u)` -# """ - -# struct IncompatibleMassMatrixError <: Exception -# sz::Int -# len::Int -# end - -# function Base.showerror(io::IO, e::IncompatibleMassMatrixError) -# println(io, MASS_MATRIX_ERROR_MESSAGE) -# print(io, "size(prob.f.mass_matrix,1): ") -# println(io, e.sz) -# print(io, "length(u0): ") -# println(e.len) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# const LATE_BINDING_TSTOPS_ERROR_MESSAGE = """ -# This solver does not support providing `tstops` as a function. -# Consider using a different solver or providing `tstops` as an array -# of times. -# """ - -# struct LateBindingTstopsNotSupportedError <: Exception end - -# function Base.showerror(io::IO, e::LateBindingTstopsNotSupportedError) -# println(io, LATE_BINDING_TSTOPS_ERROR_MESSAGE) -# println(io, TruncatedStacktraces.VERBOSE_MSG) -# end - -# """ -# $(TYPEDSIGNATURES) - -# Given the index provider `indp` used to construct the problem `prob` being solved, return -# an updated `prob` to be used for solving. All implementations should accept arbitrary -# keyword arguments. - -# Should be called before the problem is solved, after performing type-promotion on the -# problem. If the returned problem is not `===` the provided `prob`, it is assumed to -# contain the `u0` and `p` passed as keyword arguments. - -# # Keyword Arguments - -# - `u0`, `p`: Override values for `state_values(prob)` and `parameter_values(prob)` which -# should be used instead of the ones in `prob`. -# """ -# function get_updated_symbolic_problem(indp, prob; kw...) -# return prob -# end - function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing, kwargs...) kwargshandle = kwargshandle === nothing ? SciMLBase.KeywordArgError : kwargshandle @@ -1240,22 +710,6 @@ function solve(prob::AbstractJumpProblem, args...; kwargs...) __solve(prob, args...; kwargs...) end -# function checkkwargs(kwargshandle; kwargs...) -# if any(x -> x ∉ allowedkeywords, keys(kwargs)) -# if kwargshandle == KeywordArgError -# throw(CommonKwargError(kwargs)) -# elseif kwargshandle == KeywordArgWarn -# @warn KWARGWARN_MESSAGE -# unrecognized = setdiff(keys(kwargs), allowedkeywords) -# print("Unrecognized keyword arguments: ") -# printstyled(unrecognized; bold = true, color = :red) -# print("\n\n") -# else -# @assert kwargshandle == KeywordArgSilent -# end -# end -# end - function get_concrete_problem(prob::AbstractJumpProblem, isadapt; kwargs...) get_updated_symbolic_problem(SciMLBase.get_root_indp(prob), prob; kwargs...) end @@ -1441,100 +895,6 @@ function get_concrete_tspan(prob, isadapt, kwargs, p) tspan end -# function isconcreteu0(prob, t0, kwargs) -# !eval_u0(prob.u0) && prob.u0 !== nothing && !isdistribution(prob.u0) -# end - -# function isconcretedu0(prob, t0, kwargs) -# !eval_u0(prob.u0) && prob.du0 !== nothing && !isdistribution(prob.du0) -# end - -# function get_concrete_u0(prob, isadapt, t0, kwargs) -# if eval_u0(prob.u0) -# u0 = prob.u0(prob.p, t0) -# elseif haskey(kwargs, :u0) -# u0 = kwargs[:u0] -# else -# u0 = prob.u0 -# end - -# isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) - -# _u0 = handle_distribution_u0(u0) - -# if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) -# throw(IncompatibleInitialConditionError()) -# end - -# nu0 = length(something(_u0, ())) -# if isdefined(prob.f, :mass_matrix) && prob.f.mass_matrix !== nothing && -# prob.f.mass_matrix isa AbstractArray && -# size(prob.f.mass_matrix, 1) !== nu0 -# throw(IncompatibleMassMatrixError(size(prob.f.mass_matrix, 1), nu0)) -# end - -# if _u0 isa Tuple -# throw(TupleStateError()) -# end - -# _u0 -# end - -# function get_concrete_u0(prob::BVProblem, isadapt, t0, kwargs) -# if haskey(kwargs, :u0) -# u0 = kwargs[:u0] -# else -# u0 = prob.u0 -# end - -# isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) - -# _u0 = handle_distribution_u0(u0) - -# if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) -# throw(IncompatibleInitialConditionError()) -# end - -# if _u0 isa Tuple -# throw(TupleStateError()) -# end - -# return _u0 -# end - -# function get_concrete_du0(prob, isadapt, t0, kwargs) -# if eval_u0(prob.du0) -# du0 = prob.du0(prob.p, t0) -# elseif haskey(kwargs, :du0) -# du0 = kwargs[:du0] -# else -# du0 = prob.du0 -# end - -# isadapt && eltype(du0) <: Integer && (du0 = float.(du0)) - -# _du0 = handle_distribution_u0(du0) - -# if isinplace(prob) && (_du0 isa Number || _du0 isa SArray) -# throw(IncompatibleInitialConditionError()) -# end - -# _du0 -# end - -# function get_concrete_p(prob, kwargs) -# if haskey(kwargs, :p) -# p = kwargs[:p] -# else -# p = prob.p -# end -# end - -# handle_distribution_u0(_u0) = _u0 - -# eval_u0(u0::Function) = true -# eval_u0(u0) = false - function __solve( prob::AbstractDEProblem, args...; default_set = false, second_time = false, kwargs...) @@ -1626,23 +986,6 @@ function check_prob_alg_pairing(prob, alg) end end -# @inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) -# if isempty(solve_args) || isnothing(first(solve_args)) -# if haskey(solve_kwargs, :alg) -# solve_kwargs[:alg] -# elseif haskey(prob_kwargs, :alg) -# prob_kwargs[:alg] -# else -# nothing -# end -# elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm && -# !(first(solve_args) isa SciMLBase.EnsembleAlgorithm) -# first(solve_args) -# else -# nothing -# end -# end - ################### Differentiation """ From 96cb04a6313d4d89506b9f80c00626d204b6661b Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 27 Aug 2025 13:56:37 -0400 Subject: [PATCH 09/17] remove _concrete_solve_adjoint and _concrete_solve_forward --- src/solve.jl | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index a3143c4a7..ec3f31c79 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1069,39 +1069,3 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba end end -#### -# Catch undefined AD overload cases - -const ADJOINT_NOT_FOUND_MESSAGE = """ - Compatibility with reverse-mode automatic differentiation requires SciMLSensitivity.jl. - Please install SciMLSensitivity.jl and do `using SciMLSensitivity`/`import SciMLSensitivity` - for this functionality. For more details, see https://sensitivity.sciml.ai/dev/. - """ - -struct AdjointNotFoundError <: Exception end - -function Base.showerror(io::IO, e::AdjointNotFoundError) - print(io, ADJOINT_NOT_FOUND_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -function _concrete_solve_adjoint(args...; kwargs...) - throw(AdjointNotFoundError()) -end - -const FORWARD_SENSITIVITY_NOT_FOUND_MESSAGE = """ - Compatibility with forward-mode automatic differentiation requires SciMLSensitivity.jl. - Please install SciMLSensitivity.jl and do `using SciMLSensitivity`/`import SciMLSensitivity` - for this functionality. For more details, see https://sensitivity.sciml.ai/dev/. - """ - -struct ForwardSensitivityNotFoundError <: Exception end - -function Base.showerror(io::IO, e::ForwardSensitivityNotFoundError) - print(io, FORWARD_SENSITIVITY_NOT_FOUND_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -function _concrete_solve_forward(args...; kwargs...) - throw(ForwardSensitivityNotFoundError()) -end From 2d2c533a529a421383933256b7ad154ef77e4668 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 27 Aug 2025 15:45:56 -0400 Subject: [PATCH 10/17] add imports for _concrete solves --- src/DiffEqBase.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index cc8f125b6..0d16ba27b 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -95,7 +95,7 @@ using SciMLBase: @def, DEIntegrator, AbstractDEProblem, import SciMLBase: solve, init, step!, solve!, __init, __solve, update_coefficients!, update_coefficients, isadaptive, wrapfun_oop, wrapfun_iip, unwrap_fw, promote_tspan, set_u!, set_t!, set_ut!, - extract_alg, checkkwargs, has_kwargs, + extract_alg, checkkwargs, has_kwargs, _concrete_solve_adjoint, _concrete_solve_forward, eltypedual, get_updated_symbolic_problem, get_concrete_p, get_concrete_u0, promote_u0, isconcreteu0, isconcretedu0, get_concrete_du0, _reshape, value, unitfulvalue, anyeltypedual, allowedkeywords, sse, totallength, __sum, DualEltypeChecker, KeywordArgError, KeywordArgWarn, KeywordArgSilent, KWARGWARN_MESSAGE, KWARGERROR_MESSAGE, From a1503998ef257f30d383500b6bdacf464ae5edbc Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 28 Aug 2025 10:19:04 -0400 Subject: [PATCH 11/17] remove redundant mooncake rules --- ext/DiffEqBaseMooncakeExt.jl | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/ext/DiffEqBaseMooncakeExt.jl b/ext/DiffEqBaseMooncakeExt.jl index 16e4b46e5..d0afcd110 100644 --- a/ext/DiffEqBaseMooncakeExt.jl +++ b/ext/DiffEqBaseMooncakeExt.jl @@ -29,18 +29,4 @@ import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive, }, true,) -@zero_adjoint MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any} -@is_primitive MinimalCtx Tuple{ - typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake), SciMLBase.ChainRulesOriginator -} - -@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = SciMLBase.MooncakeOriginator() - -function rrule!!( - f::CoDual{typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake)}, - X::CoDual{SciMLBase.ChainRulesOriginator} -) - return zero_fcodual(SciMLBase.MooncakeOriginator()), NoPullback(f, X) -end - end From ebcace992d87eafce1368ba67e272502e2ec6bf8 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 28 Aug 2025 10:19:25 -0400 Subject: [PATCH 12/17] bump compat for SciMLBase --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d00bd380f..88e42a4ab 100644 --- a/Project.toml +++ b/Project.toml @@ -94,7 +94,7 @@ Printf = "1.9" RecursiveArrayTools = "3.1" Reexport = "1.0" ReverseDiff = "1" -SciMLBase = "2.112.0" +SciMLBase = "2.114.0" SciMLOperators = "1" SciMLStructures = "1.5" Setfield = "1" From 0da247f67d2e60e2a6b4c8d49da9dbf2447e9989 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 28 Aug 2025 10:21:20 -0400 Subject: [PATCH 13/17] remove solve dispatches for EnsembleProblems --- src/solve.jl | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index ec3f31c79..7bf098995 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -691,17 +691,6 @@ function solve_call(prob::SteadyStateProblem, kwargs...) end -function solve(prob::EnsembleProblem, args...; kwargs...) - alg = extract_alg(args, kwargs, kwargs) - if length(args) > 1 - __solve(prob, alg, Base.tail(args)...; kwargs...) - else - __solve(prob, alg; kwargs...) - end -end -function solve(prob::SciMLBase.WeightedEnsembleProblem, args...; kwargs...) - SciMLBase.WeightedEnsembleSolution(solve(prob.ensembleprob), prob.weights) -end function solve(prob::AbstractNoiseProblem, args...; kwargs...) __solve(prob, args...; kwargs...) end From b2b5adba47547ef38c5944c43bc8d8037de63276 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 28 Aug 2025 10:59:18 -0400 Subject: [PATCH 14/17] format solve --- src/solve.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 7bf098995..512f1c689 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -189,7 +189,7 @@ function build_null_solution(prob::AbstractDEProblem, args...; save_everystep = true, save_on = true, save_start = save_everystep || isempty(saveat) || - saveat isa Number || prob.tspan[1] in saveat, + saveat isa Number || prob.tspan[1] in saveat, save_end = true, kwargs...) ts = if saveat === () @@ -222,7 +222,7 @@ function build_null_solution( save_everystep = true, save_on = true, save_start = save_everystep || isempty(saveat) || - saveat isa Number || prob.tspan[1] in saveat, + saveat isa Number || prob.tspan[1] in saveat, save_end = true, kwargs...) prob, success = hack_null_solution_init(prob) @@ -1057,4 +1057,3 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba _concrete_solve_forward(_prob, alg, sensealg, u0, p, originator; kwargs...) end end - From 8a47130fa442696067c421962dca38a05a1daaed Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 28 Aug 2025 12:09:15 -0400 Subject: [PATCH 15/17] change errors to reference SciMLBase --- test/downstream/kwarg_warn.jl | 13 ++++++------ test/downstream/solve_error_handling.jl | 28 ++++++++++++------------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/test/downstream/kwarg_warn.jl b/test/downstream/kwarg_warn.jl index 32e77976a..1cdab9881 100644 --- a/test/downstream/kwarg_warn.jl +++ b/test/downstream/kwarg_warn.jl @@ -1,4 +1,5 @@ using OrdinaryDiffEq, Test +using DiffEqBase function lorenz(du, u, p, t) du[1] = 10.0(u[2] - u[1]) du[2] = u[1] * (28.0 - u[3]) - u[2] @@ -8,10 +9,10 @@ u0 = [1.0; 0.0; 0.0] tspan = (0.0, 100.0) prob = ODEProblem(lorenz, u0, tspan) @test_nowarn sol = solve(prob, Tsit5(), reltol = 1e-6) -sol = solve(prob, Tsit5(), rel_tol = 1e-6, kwargshandle = DiffEqBase.KeywordArgWarn) -@test_logs (:warn, DiffEqBase.KWARGWARN_MESSAGE) sol=solve( - prob, Tsit5(), rel_tol = 1e-6, kwargshandle = DiffEqBase.KeywordArgWarn) -@test_throws DiffEqBase.CommonKwargError sol=solve(prob, Tsit5(), rel_tol = 1e-6) +sol = solve(prob, Tsit5(), rel_tol = 1e-6, kwargshandle = SciMLBase.KeywordArgWarn) +@test_logs (:warn, SciMLBase.KWARGWARN_MESSAGE) sol=solve( + prob, Tsit5(), rel_tol = 1e-6, kwargshandle = SciMLBase.KeywordArgWarn) +@test_throws SciMLBase.CommonKwargError sol=solve(prob, Tsit5(), rel_tol = 1e-6) -prob = ODEProblem(lorenz, u0, tspan, test = 2.0, kwargshandle = DiffEqBase.KeywordArgWarn) -@test_logs (:warn, DiffEqBase.KWARGWARN_MESSAGE) sol=solve(prob, Tsit5(), reltol = 1e-6) +prob = ODEProblem(lorenz, u0, tspan, test = 2.0, kwargshandle = SciMLBase.KeywordArgWarn) +@test_logs (:warn, SciMLsBase.KWARGWARN_MESSAGE) sol=solve(prob, Tsit5(), reltol = 1e-6) diff --git a/test/downstream/solve_error_handling.jl b/test/downstream/solve_error_handling.jl index dae035e14..eab1fadf4 100644 --- a/test/downstream/solve_error_handling.jl +++ b/test/downstream/solve_error_handling.jl @@ -10,7 +10,7 @@ function f(du, u, p, t) du .= 2.0 * u end prob = ODEProblem(f, u0, tspan) -@test_throws DiffEqBase.IncompatibleInitialConditionError sol=solve(prob, Tsit5()) +@test_throws SciMLBase.IncompatibleInitialConditionError sol=solve(prob, Tsit5()) prob = ODEProblem{false}(f, u0, tspan) sol = solve(prob, Tsit5()) @@ -18,40 +18,40 @@ sol = solve(prob, nothing, alg = Tsit5()) sol = init(prob, nothing, alg = Tsit5()) prob = ODEProblem{false}(f, 1.0 + im, tspan) -@test_throws DiffEqBase.ComplexSupportError solve(prob, CVODE_Adams()) +@test_throws SciMLBase.ComplexSupportError solve(prob, CVODE_Adams()) -@test_throws DiffEqBase.ProblemSolverPairingError solve(prob, DFBDF()) -@test_throws DiffEqBase.NonSolverError solve(prob, 5.0) +@test_throws SciMLBase.ProblemSolverPairingError solve(prob, DFBDF()) +@test_throws SciMLBase.NonSolverError solve(prob, 5.0) prob = ODEProblem{false}(f, u0, (nothing, nothing)) -@test_throws DiffEqBase.NoTspanError solve(prob, Tsit5()) +@test_throws SciMLBase.NoTspanError solve(prob, Tsit5()) prob = ODEProblem{false}(f, u0, (NaN, 1.0)) -@test_throws DiffEqBase.NaNTspanError solve(prob, Tsit5()) +@test_throws SciMLBase.NaNTspanError solve(prob, Tsit5()) prob = ODEProblem{false}(f, u0, (1.0, NaN)) -@test_throws DiffEqBase.NaNTspanError solve(prob, Tsit5()) +@test_throws SciMLBase.NaNTspanError solve(prob, Tsit5()) prob = ODEProblem{false}(f, Any[1.0, 1.0f0], tspan) -@test_throws DiffEqBase.NonConcreteEltypeError solve(prob, Tsit5()) +@test_throws SciMLBase.NonConcreteEltypeError solve(prob, Tsit5()) prob = ODEProblem{false}(f, (1.0, 1.0f0), tspan) -@test_throws DiffEqBase.TupleStateError solve(prob, Tsit5()) +@test_throws SciMLBase.TupleStateError solve(prob, Tsit5()) prob = ODEProblem{false}(f, u0, (0.0 + im, 1.0)) -@test_throws DiffEqBase.ComplexTspanError solve(prob, Tsit5()) +@test_throws SciMLBase.ComplexTspanError solve(prob, Tsit5()) for u0 in ([0.0, 0.0], nothing) fmm = ODEFunction(f, mass_matrix = zeros(3, 3)) prob = ODEProblem(fmm, u0, (0.0, 1.0)) - @test_throws DiffEqBase.IncompatibleMassMatrixError solve(prob, Tsit5()) + @test_throws SciMLBase.IncompatibleMassMatrixError solve(prob, Tsit5()) end # Allow empty mass matrix for empty u0 fmm = ODEFunction((du, u, t) -> nothing, mass_matrix = zeros(0, 0)) prob = ODEProblem(fmm, nothing, (0.0, 1.0)) sol = solve(prob, Tsit5()) -@test isa(sol, DiffEqBase.ODESolution) +@test isa(sol, SciMLBase.ODESolution) f(du, u, p, t) = du .= 1.01u function g(du, u, p, t) @@ -71,7 +71,7 @@ prob = SDEProblem(f, (0.0, 1.0), noise_rate_prototype = complex(zeros(2, 4)), noise = StochasticDiffEq.RealWienerProcess(0.0, zeros(3))) -@test_throws DiffEqBase.NoiseSizeIncompatabilityError solve(prob, LambaEM()) +@test_throws SciMLBase.NoiseSizeIncompatabilityError solve(prob, LambaEM()) function g!(du, u, p, t) du[1] .= u[1] + ones(3, 3) @@ -79,4 +79,4 @@ function g!(du, u, p, t) end u0 = [zeros(3, 3), zeros(3, 3)] prob = ODEProblem(g!, u0, (0, 1.0)) -@test_throws DiffEqBase.NonNumberEltypeError solve(prob, Tsit5()) +@test_throws SciMLBase.NonNumberEltypeError solve(prob, Tsit5()) From 70d5a97df73873bc668ae12be6b1871629d967c8 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 28 Aug 2025 12:19:26 -0400 Subject: [PATCH 16/17] fix typo --- test/downstream/kwarg_warn.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/downstream/kwarg_warn.jl b/test/downstream/kwarg_warn.jl index 1cdab9881..879971dd9 100644 --- a/test/downstream/kwarg_warn.jl +++ b/test/downstream/kwarg_warn.jl @@ -15,4 +15,4 @@ sol = solve(prob, Tsit5(), rel_tol = 1e-6, kwargshandle = SciMLBase.KeywordArgWa @test_throws SciMLBase.CommonKwargError sol=solve(prob, Tsit5(), rel_tol = 1e-6) prob = ODEProblem(lorenz, u0, tspan, test = 2.0, kwargshandle = SciMLBase.KeywordArgWarn) -@test_logs (:warn, SciMLsBase.KWARGWARN_MESSAGE) sol=solve(prob, Tsit5(), reltol = 1e-6) +@test_logs (:warn, SciMLBase.KWARGWARN_MESSAGE) sol=solve(prob, Tsit5(), reltol = 1e-6) From 2ed36161c8f0986e650aa691ff2617bb068fb433 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 28 Aug 2025 12:43:00 -0400 Subject: [PATCH 17/17] fix unitful tests --- test/downstream/unitful.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/downstream/unitful.jl b/test/downstream/unitful.jl index 89bf38836..68caf0a68 100644 --- a/test/downstream/unitful.jl +++ b/test/downstream/unitful.jl @@ -4,6 +4,6 @@ prob = ODEProblem(f, [2.0u"m"], (0.0u"s", Inf * u"s")) intg = init(prob, Tsit5()) @test_nowarn step!(intg, 0.02u"s", true) -@test DiffEqBase.unitfulvalue(u"1/s") == u"1/s" -@test DiffEqBase.value(ForwardDiff.Dual(1) * u"1/s") == 1 -@test DiffEqBase.unitfulvalue(ForwardDiff.Dual(1) * u"1/s") == u"1/s" +@test SciMLBase.unitfulvalue(u"1/s") == u"1/s" +@test SciMLBase.value(ForwardDiff.Dual(1) * u"1/s") == 1 +@test SciMLBase.unitfulvalue(ForwardDiff.Dual(1) * u"1/s") == u"1/s"