Skip to content

feat: more robust inputs/outputs handling #3795

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
19 changes: 10 additions & 9 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Symbolics: get_variables
Return all variables that mare marked as inputs. See also [`unbound_inputs`](@ref)
See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref)
"""
inputs(sys) = [filter(isinput, unknowns(sys)); filter(isinput, parameters(sys))]
inputs(sys) = collect(get_inputs(sys))

"""
outputs(sys)
Expand All @@ -14,13 +14,7 @@ Return all variables that mare marked as outputs. See also [`unbound_outputs`](@
See also [`bound_outputs`](@ref), [`unbound_outputs`](@ref)
"""
function outputs(sys)
o = observed(sys)
rhss = [eq.rhs for eq in o]
lhss = [eq.lhs for eq in o]
unique([filter(isoutput, unknowns(sys))
filter(isoutput, parameters(sys))
filter(x -> iscall(x) && isoutput(x), rhss) # observed can return equations with complicated expressions, we are only looking for single Terms
filter(x -> iscall(x) && isoutput(x), lhss)])
return collect(get_outputs(sys))
end

"""
Expand Down Expand Up @@ -288,7 +282,12 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
push!(new_fullvars, v)
end
end
ninputs == 0 && return state
if ninputs == 0
@set! sys.inputs = OrderedSet{BasicSymbolic}()
@set! sys.outputs = OrderedSet{BasicSymbolic}(filter(isoutput, fullvars))
state.sys = sys
return state
end

nvars = ndsts(graph) - ninputs
new_graph = BipartiteGraph(nsrcs(graph), nvars, Val(false))
Expand Down Expand Up @@ -318,6 +317,8 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
ps = parameters(sys)

@set! sys.ps = [ps; new_parameters]
@set! sys.inputs = OrderedSet{BasicSymbolic}(new_parameters)
@set! sys.outputs = OrderedSet{BasicSymbolic}(filter(isoutput, fullvars))
@set! state.sys = sys
@set! state.fullvars = Vector{BasicSymbolic}(new_fullvars)
@set! state.structure = structure
Expand Down
3 changes: 1 addition & 2 deletions src/linearization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,7 @@ struct IONotFoundError <: Exception
end

function Base.showerror(io::IO, err::IONotFoundError)
println(io,
"The following $(err.variant) provided to `mtkcompile` were not found in the system:")
println(io, "The following $(err.variant) provided to `mtkcompile` were not found in the system:")
maybe_namespace_issue = false
for var in err.not_found
println(io, " ", var)
Expand Down
13 changes: 13 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,8 @@ const SYS_PROPS = [:eqs
:parent
:is_dde
:tstops
:inputs
:outputs
:index_cache
:isscheduled
:costs
Expand Down Expand Up @@ -1820,6 +1822,17 @@ function push_vars!(stmt, name, typ, vars)
ex = nameof(s)
end
push!(vars_expr.args, ex)

meta_kvps = Expr[]
if isinput(s)
push!(meta_kvps, :(input = true))
end
if isoutput(s)
push!(meta_kvps, :(output = true))
end
if !isempty(meta_kvps)
push!(vars_expr.args, Expr(:vect, meta_kvps...))
end
end
push!(stmt, :($name = $collect($vars_expr)))
return
Expand Down
49 changes: 43 additions & 6 deletions src/systems/system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,16 @@ struct System <: IntermediateDeprecationSystem
"""
tstops::Vector{Any}
"""
$INTERNAL_FIELD_WARNING
The list of input variables of the system.
"""
inputs::OrderedSet{BasicSymbolic}
"""
$INTERNAL_FIELD_WARNING
The list of output variables of the system.
"""
outputs::OrderedSet{BasicSymbolic}
"""
The `TearingState` of the system post-simplification with `mtkcompile`.
"""
tearing_state::Any
Expand Down Expand Up @@ -255,8 +265,9 @@ struct System <: IntermediateDeprecationSystem
brownians, iv, observed, parameter_dependencies, var_to_name, name, description,
defaults, guesses, systems, initialization_eqs, continuous_events, discrete_events,
connector_type, assertions = Dict{BasicSymbolic, String}(),
metadata = MetadataT(), gui_metadata = nothing,
is_dde = false, tstops = [], tearing_state = nothing, namespacing = true,
metadata = MetadataT(), gui_metadata = nothing, is_dde = false, tstops = [],
inputs = Set{BasicSymbolic}(), outputs = Set{BasicSymbolic}(),
tearing_state = nothing, namespacing = true,
complete = false, index_cache = nothing, ignored_connections = nothing,
preface = nothing, parent = nothing, initializesystem = nothing,
is_initializesystem = false, is_discrete = false, isscheduled = false,
Expand Down Expand Up @@ -296,7 +307,8 @@ struct System <: IntermediateDeprecationSystem
observed, parameter_dependencies, var_to_name, name, description, defaults,
guesses, systems, initialization_eqs, continuous_events, discrete_events,
connector_type, assertions, metadata, gui_metadata, is_dde,
tstops, tearing_state, namespacing, complete, index_cache, ignored_connections,
tstops, inputs, outputs, tearing_state, namespacing,
complete, index_cache, ignored_connections,
preface, parent, initializesystem, is_initializesystem, is_discrete,
isscheduled, schedule)
end
Expand Down Expand Up @@ -332,7 +344,8 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
continuous_events = SymbolicContinuousCallback[], discrete_events = SymbolicDiscreteCallback[],
connector_type = nothing, assertions = Dict{BasicSymbolic, String}(),
metadata = MetadataT(), gui_metadata = nothing,
is_dde = nothing, tstops = [], tearing_state = nothing,
is_dde = nothing, tstops = [], inputs = OrderedSet{BasicSymbolic}(),
outputs = OrderedSet{BasicSymbolic}(), tearing_state = nothing,
ignored_connections = nothing, parent = nothing,
description = "", name = nothing, discover_from_metadata = true,
initializesystem = nothing, is_initializesystem = false, is_discrete = false,
Expand Down Expand Up @@ -367,15 +380,35 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];

defaults = anydict(defaults)
guesses = anydict(guesses)
inputs = OrderedSet{BasicSymbolic}(inputs)
outputs = OrderedSet{BasicSymbolic}(outputs)
for subsys in systems
for var in ModelingToolkit.inputs(subsys)
push!(inputs, renamespace(subsys, var))
end
for var in ModelingToolkit.outputs(subsys)
push!(outputs, renamespace(subsys, var))
end
end
var_to_name = anydict()

let defaults = discover_from_metadata ? defaults : Dict(),
guesses = discover_from_metadata ? guesses : Dict()
guesses = discover_from_metadata ? guesses : Dict(),
inputs = discover_from_metadata ? inputs : Set(),
outputs = discover_from_metadata ? outputs : Set()

process_variables!(var_to_name, defaults, guesses, dvs)
process_variables!(var_to_name, defaults, guesses, ps)
process_variables!(var_to_name, defaults, guesses, [eq.lhs for eq in observed])
process_variables!(var_to_name, defaults, guesses, [eq.rhs for eq in observed])

for var in dvs
if isinput(var)
push!(inputs, var)
elseif isoutput(var)
push!(outputs, var)
end
end
end
filter!(!(isnothing ∘ last), defaults)
filter!(!(isnothing ∘ last), guesses)
Expand Down Expand Up @@ -417,7 +450,8 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
costs, consolidate, dvs, ps, brownians, iv, observed, Equation[],
var_to_name, name, description, defaults, guesses, systems, initialization_eqs,
continuous_events, discrete_events, connector_type, assertions, metadata, gui_metadata, is_dde,
tstops, tearing_state, true, false, nothing, ignored_connections, preface, parent,
tstops, inputs, outputs, tearing_state, true, false,
nothing, ignored_connections, preface, parent,
initializesystem, is_initializesystem, is_discrete; checks)
end

Expand Down Expand Up @@ -731,6 +765,7 @@ function flatten(sys::System, noeqs = false)
discrete_events = discrete_events(sys), assertions = assertions(sys),
is_dde = is_dde(sys), tstops = symbolic_tstops(sys),
initialization_eqs = initialization_equations(sys),
inputs = inputs(sys), outputs = outputs(sys),
# without this, any defaults/guesses obtained from metadata that were
# later removed by the user will be re-added. Right now, we just want to
# retain `defaults(sys)` as-is.
Expand Down Expand Up @@ -1143,6 +1178,8 @@ function Base.isapprox(sysa::System, sysb::System)
isequal(get_metadata(sysa), get_metadata(sysb)) &&
isequal(get_is_dde(sysa), get_is_dde(sysb)) &&
issetequal(get_tstops(sysa), get_tstops(sysb)) &&
issetequal(get_inputs(sysa), get_inputs(sysb)) &&
issetequal(get_outputs(sysa), get_outputs(sysb)) &&
safe_issetequal(get_ignored_connections(sysa), get_ignored_connections(sysb)) &&
isequal(get_is_initializesystem(sysa), get_is_initializesystem(sysb)) &&
isequal(get_is_discrete(sysa), get_is_discrete(sysb)) &&
Expand Down
8 changes: 2 additions & 6 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -993,13 +993,9 @@ function _mtkcompile!(state::TearingState; simplify = false,
else
check_consistency = true
end
has_io = !isempty(inputs) || !isempty(outputs) !== nothing ||
!isempty(disturbance_inputs)
orig_inputs = Set()
if has_io
ModelingToolkit.markio!(state, orig_inputs, inputs, outputs, disturbance_inputs)
state = ModelingToolkit.inputs_to_parameters!(state, [inputs; disturbance_inputs])
end
ModelingToolkit.markio!(state, orig_inputs, inputs, outputs, disturbance_inputs)
state = ModelingToolkit.inputs_to_parameters!(state, [inputs; disturbance_inputs])
trivial_tearing!(state)
sys, mm = ModelingToolkit.alias_elimination!(state; fully_determined, kwargs...)
if check_consistency
Expand Down
17 changes: 17 additions & 0 deletions test/input_output_handling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,20 @@ end
x = [1.0]
@test_nowarn f[1](x, u, p, 0.0)
end

@testset "Observed inputs and outputs" begin
@variables x(t) y(t) [input = true] z(t) [output = true]
eqs = [D(x) ~ x + y + z
y ~ z]
@named sys = System(eqs, t)
@test issetequal(ModelingToolkit.inputs(sys), [y])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the PR that should address the ordering an number of inputs after mtkcompile(..., inputs = [u1,u2])? If so, it would be nice to include a test case where this failed
https://github.com/JuliaComputing/DyadControlSystems.jl/actions/runs/17095167938/job/48477801369?pr=644#step:7:1727

I notice also that the build that was triggered still fails in almost the same place, but it now has a different error message than it had before

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ordering was already fixed in #3804. We have an assertion for the ordering being maintained in complete, so if by chance something goes wrong it will surface as an error. I don't know how we would test it, since there isn't a specific condition where the ordering goes wrong. It's just that if we're not careful it might get shuffled around, which is what the assertion detects.

The new failure in DyadControlSystems seems to me like a missing splat operation. It's trying to pass a NamedTuple of 4 matrices where the function expects them as 4 different arguments.

Copy link
Contributor

@baggepinnen baggepinnen Aug 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, so it has been a change in MTKv10 that went unnoticed before this fix then, I'll fix it separately.

@test issetequal(ModelingToolkit.outputs(sys), [z])

ss1 = mtkcompile(sys, inputs = [y], outputs = [z])
@test issetequal(ModelingToolkit.inputs(ss1), [y])
@test issetequal(ModelingToolkit.outputs(ss1), [z])

ss2 = mtkcompile(sys, inputs = [z], outputs = [y])
@test issetequal(ModelingToolkit.inputs(ss2), [z])
@test issetequal(ModelingToolkit.outputs(ss2), [y])
end
Loading