diff --git a/src/inputoutput.jl b/src/inputoutput.jl index 7beb86cf74..c113c4e753 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -5,7 +5,7 @@ using Symbolics: get_variables Return all variables that mare marked as inputs. See also [`unbound_inputs`](@ref) See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref) """ -inputs(sys) = [filter(isinput, unknowns(sys)); filter(isinput, parameters(sys))] +inputs(sys) = collect(get_inputs(sys)) """ outputs(sys) @@ -14,13 +14,7 @@ Return all variables that mare marked as outputs. See also [`unbound_outputs`](@ See also [`bound_outputs`](@ref), [`unbound_outputs`](@ref) """ function outputs(sys) - o = observed(sys) - rhss = [eq.rhs for eq in o] - lhss = [eq.lhs for eq in o] - unique([filter(isoutput, unknowns(sys)) - filter(isoutput, parameters(sys)) - filter(x -> iscall(x) && isoutput(x), rhss) # observed can return equations with complicated expressions, we are only looking for single Terms - filter(x -> iscall(x) && isoutput(x), lhss)]) + return collect(get_outputs(sys)) end """ @@ -288,7 +282,12 @@ function inputs_to_parameters!(state::TransformationState, inputsyms) push!(new_fullvars, v) end end - ninputs == 0 && return state + if ninputs == 0 + @set! sys.inputs = OrderedSet{BasicSymbolic}() + @set! sys.outputs = OrderedSet{BasicSymbolic}(filter(isoutput, fullvars)) + state.sys = sys + return state + end nvars = ndsts(graph) - ninputs new_graph = BipartiteGraph(nsrcs(graph), nvars, Val(false)) @@ -318,6 +317,8 @@ function inputs_to_parameters!(state::TransformationState, inputsyms) ps = parameters(sys) @set! sys.ps = [ps; new_parameters] + @set! sys.inputs = OrderedSet{BasicSymbolic}(new_parameters) + @set! sys.outputs = OrderedSet{BasicSymbolic}(filter(isoutput, fullvars)) @set! state.sys = sys @set! state.fullvars = Vector{BasicSymbolic}(new_fullvars) @set! state.structure = structure diff --git a/src/linearization.jl b/src/linearization.jl index 5c0c174cdc..3c11484d61 100644 --- a/src/linearization.jl +++ b/src/linearization.jl @@ -572,8 +572,7 @@ struct IONotFoundError <: Exception end function Base.showerror(io::IO, err::IONotFoundError) - println(io, - "The following $(err.variant) provided to `mtkcompile` were not found in the system:") + println(io, "The following $(err.variant) provided to `mtkcompile` were not found in the system:") maybe_namespace_issue = false for var in err.not_found println(io, " ", var) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index f9d9a196b4..0bd05bb4b9 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -784,6 +784,8 @@ const SYS_PROPS = [:eqs :parent :is_dde :tstops + :inputs + :outputs :index_cache :isscheduled :costs @@ -1820,6 +1822,17 @@ function push_vars!(stmt, name, typ, vars) ex = nameof(s) end push!(vars_expr.args, ex) + + meta_kvps = Expr[] + if isinput(s) + push!(meta_kvps, :(input = true)) + end + if isoutput(s) + push!(meta_kvps, :(output = true)) + end + if !isempty(meta_kvps) + push!(vars_expr.args, Expr(:vect, meta_kvps...)) + end end push!(stmt, :($name = $collect($vars_expr))) return diff --git a/src/systems/system.jl b/src/systems/system.jl index 08421e04cc..6db36ebd36 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -190,6 +190,16 @@ struct System <: IntermediateDeprecationSystem """ tstops::Vector{Any} """ + $INTERNAL_FIELD_WARNING + The list of input variables of the system. + """ + inputs::OrderedSet{BasicSymbolic} + """ + $INTERNAL_FIELD_WARNING + The list of output variables of the system. + """ + outputs::OrderedSet{BasicSymbolic} + """ The `TearingState` of the system post-simplification with `mtkcompile`. """ tearing_state::Any @@ -255,8 +265,9 @@ struct System <: IntermediateDeprecationSystem brownians, iv, observed, parameter_dependencies, var_to_name, name, description, defaults, guesses, systems, initialization_eqs, continuous_events, discrete_events, connector_type, assertions = Dict{BasicSymbolic, String}(), - metadata = MetadataT(), gui_metadata = nothing, - is_dde = false, tstops = [], tearing_state = nothing, namespacing = true, + metadata = MetadataT(), gui_metadata = nothing, is_dde = false, tstops = [], + inputs = Set{BasicSymbolic}(), outputs = Set{BasicSymbolic}(), + tearing_state = nothing, namespacing = true, complete = false, index_cache = nothing, ignored_connections = nothing, preface = nothing, parent = nothing, initializesystem = nothing, is_initializesystem = false, is_discrete = false, isscheduled = false, @@ -296,7 +307,8 @@ struct System <: IntermediateDeprecationSystem observed, parameter_dependencies, var_to_name, name, description, defaults, guesses, systems, initialization_eqs, continuous_events, discrete_events, connector_type, assertions, metadata, gui_metadata, is_dde, - tstops, tearing_state, namespacing, complete, index_cache, ignored_connections, + tstops, inputs, outputs, tearing_state, namespacing, + complete, index_cache, ignored_connections, preface, parent, initializesystem, is_initializesystem, is_discrete, isscheduled, schedule) end @@ -332,7 +344,8 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; continuous_events = SymbolicContinuousCallback[], discrete_events = SymbolicDiscreteCallback[], connector_type = nothing, assertions = Dict{BasicSymbolic, String}(), metadata = MetadataT(), gui_metadata = nothing, - is_dde = nothing, tstops = [], tearing_state = nothing, + is_dde = nothing, tstops = [], inputs = OrderedSet{BasicSymbolic}(), + outputs = OrderedSet{BasicSymbolic}(), tearing_state = nothing, ignored_connections = nothing, parent = nothing, description = "", name = nothing, discover_from_metadata = true, initializesystem = nothing, is_initializesystem = false, is_discrete = false, @@ -367,15 +380,35 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; defaults = anydict(defaults) guesses = anydict(guesses) + inputs = OrderedSet{BasicSymbolic}(inputs) + outputs = OrderedSet{BasicSymbolic}(outputs) + for subsys in systems + for var in ModelingToolkit.inputs(subsys) + push!(inputs, renamespace(subsys, var)) + end + for var in ModelingToolkit.outputs(subsys) + push!(outputs, renamespace(subsys, var)) + end + end var_to_name = anydict() let defaults = discover_from_metadata ? defaults : Dict(), - guesses = discover_from_metadata ? guesses : Dict() + guesses = discover_from_metadata ? guesses : Dict(), + inputs = discover_from_metadata ? inputs : Set(), + outputs = discover_from_metadata ? outputs : Set() process_variables!(var_to_name, defaults, guesses, dvs) process_variables!(var_to_name, defaults, guesses, ps) process_variables!(var_to_name, defaults, guesses, [eq.lhs for eq in observed]) process_variables!(var_to_name, defaults, guesses, [eq.rhs for eq in observed]) + + for var in dvs + if isinput(var) + push!(inputs, var) + elseif isoutput(var) + push!(outputs, var) + end + end end filter!(!(isnothing ∘ last), defaults) filter!(!(isnothing ∘ last), guesses) @@ -417,7 +450,8 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; costs, consolidate, dvs, ps, brownians, iv, observed, Equation[], var_to_name, name, description, defaults, guesses, systems, initialization_eqs, continuous_events, discrete_events, connector_type, assertions, metadata, gui_metadata, is_dde, - tstops, tearing_state, true, false, nothing, ignored_connections, preface, parent, + tstops, inputs, outputs, tearing_state, true, false, + nothing, ignored_connections, preface, parent, initializesystem, is_initializesystem, is_discrete; checks) end @@ -731,6 +765,7 @@ function flatten(sys::System, noeqs = false) discrete_events = discrete_events(sys), assertions = assertions(sys), is_dde = is_dde(sys), tstops = symbolic_tstops(sys), initialization_eqs = initialization_equations(sys), + inputs = inputs(sys), outputs = outputs(sys), # without this, any defaults/guesses obtained from metadata that were # later removed by the user will be re-added. Right now, we just want to # retain `defaults(sys)` as-is. @@ -1143,6 +1178,8 @@ function Base.isapprox(sysa::System, sysb::System) isequal(get_metadata(sysa), get_metadata(sysb)) && isequal(get_is_dde(sysa), get_is_dde(sysb)) && issetequal(get_tstops(sysa), get_tstops(sysb)) && + issetequal(get_inputs(sysa), get_inputs(sysb)) && + issetequal(get_outputs(sysa), get_outputs(sysb)) && safe_issetequal(get_ignored_connections(sysa), get_ignored_connections(sysb)) && isequal(get_is_initializesystem(sysa), get_is_initializesystem(sysb)) && isequal(get_is_discrete(sysa), get_is_discrete(sysb)) && diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 5e33a2da5a..a1460731cb 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -993,13 +993,9 @@ function _mtkcompile!(state::TearingState; simplify = false, else check_consistency = true end - has_io = !isempty(inputs) || !isempty(outputs) !== nothing || - !isempty(disturbance_inputs) orig_inputs = Set() - if has_io - ModelingToolkit.markio!(state, orig_inputs, inputs, outputs, disturbance_inputs) - state = ModelingToolkit.inputs_to_parameters!(state, [inputs; disturbance_inputs]) - end + ModelingToolkit.markio!(state, orig_inputs, inputs, outputs, disturbance_inputs) + state = ModelingToolkit.inputs_to_parameters!(state, [inputs; disturbance_inputs]) trivial_tearing!(state) sys, mm = ModelingToolkit.alias_elimination!(state; fully_determined, kwargs...) if check_consistency diff --git a/test/input_output_handling.jl b/test/input_output_handling.jl index 6d97cf8198..cf02bc8d1e 100644 --- a/test/input_output_handling.jl +++ b/test/input_output_handling.jl @@ -468,3 +468,20 @@ end x = [1.0] @test_nowarn f[1](x, u, p, 0.0) end + +@testset "Observed inputs and outputs" begin + @variables x(t) y(t) [input = true] z(t) [output = true] + eqs = [D(x) ~ x + y + z + y ~ z] + @named sys = System(eqs, t) + @test issetequal(ModelingToolkit.inputs(sys), [y]) + @test issetequal(ModelingToolkit.outputs(sys), [z]) + + ss1 = mtkcompile(sys, inputs = [y], outputs = [z]) + @test issetequal(ModelingToolkit.inputs(ss1), [y]) + @test issetequal(ModelingToolkit.outputs(ss1), [z]) + + ss2 = mtkcompile(sys, inputs = [z], outputs = [y]) + @test issetequal(ModelingToolkit.inputs(ss2), [z]) + @test issetequal(ModelingToolkit.outputs(ss2), [y]) +end