Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 60 additions & 17 deletions src/ansys/dpf/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

"""Workflow."""

from __future__ import annotations

from enum import Enum
import logging
import os
Expand Down Expand Up @@ -713,8 +715,41 @@
out.append(self._api.work_flow_output_by_index(self, i))
return out

def safe_connect_with(
    self,
    left_workflow: Workflow,
    output_input_names: Union[tuple[str, str], dict[str, str]],
):
    """Prepend a given workflow to the current workflow for valid connections only.

    Unlike ``Workflow.connect_with``, pairs that cannot be connected (an output
    name not exposed by ``left_workflow``, or an input name not exposed by the
    current workflow) are silently ignored instead of raising.

    See Workflow.connect_with for more information on the connection logic.

    Parameters
    ----------
    left_workflow:
        The given workflow's outputs are chained with the current workflow's inputs.
    output_input_names:
        Map used to connect the outputs of the given workflow to the inputs of the current
        workflow. A single ``(output_name, input_name)`` pair may be given as a tuple.
        Check the names of available inputs and outputs for each workflow using
        `Workflow.input_names` and `Workflow.output_names`.
    """
    # Normalize the single-pair tuple form to the dict form.
    if isinstance(output_input_names, tuple):
        output_input_names = {output_input_names[0]: output_input_names[1]}
    # Hoist the two name lookups out of the per-item test: these are
    # properties and may be costly to evaluate repeatedly.
    output_names = left_workflow.output_names
    input_names = self.input_names
    # Keep only the pairs that both workflows actually expose.
    valid_connections = {
        output_name: input_name
        for output_name, input_name in output_input_names.items()
        if output_name in output_names and input_name in input_names
    }
    self.connect_with(left_workflow, valid_connections)

Check warning on line 745 in src/ansys/dpf/core/workflow.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/dpf/core/workflow.py#L745

Added line #L745 was not covered by tests

@version_requires("3.0")
def connect_with(self, left_workflow, output_input_names=None):
def connect_with(
self,
left_workflow: Workflow,
output_input_names: Union[tuple[str, str], dict[str, str]] = None,
):
"""Prepend a given workflow to the current workflow.
Updates the current workflow to include all the operators of the workflow given as argument.
Expand All @@ -724,9 +759,9 @@
Parameters
----------
left_workflow : core.Workflow
left_workflow:
The given workflow's outputs are chained with the current workflow's inputs.
output_input_names : str tuple, str dict optional
output_input_names:
Map used to connect the outputs of the given workflow to the inputs of the current
workflow.
Check the names of available inputs and outputs for each workflow using
Expand Down Expand Up @@ -791,24 +826,32 @@
"""
if output_input_names:
core_api = self._server.get_api_for_type(
capi=data_processing_capi.DataProcessingCAPI,
grpcapi=data_processing_grpcapi.DataProcessingGRPCAPI,
)
map = object_handler.ObjHandler(
data_processing_api=core_api,
internal_obj=self._api.workflow_create_connection_map_for_object(self),
)
if isinstance(output_input_names, tuple):
self._api.workflow_add_entry_connection_map(
map, output_input_names[0], output_input_names[1]
output_input_names = {output_input_names[0]: output_input_names[1]}
if isinstance(output_input_names, dict):
core_api = self._server.get_api_for_type(

Check warning on line 832 in src/ansys/dpf/core/workflow.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/dpf/core/workflow.py#L830-L832

Added lines #L830 - L832 were not covered by tests
capi=data_processing_capi.DataProcessingCAPI,
grpcapi=data_processing_grpcapi.DataProcessingGRPCAPI,
)
map = object_handler.ObjHandler(

Check warning on line 836 in src/ansys/dpf/core/workflow.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/dpf/core/workflow.py#L836

Added line #L836 was not covered by tests
data_processing_api=core_api,
internal_obj=self._api.workflow_create_connection_map_for_object(self),
)
elif isinstance(output_input_names, dict):
for key in output_input_names:
self._api.workflow_add_entry_connection_map(map, key, output_input_names[key])
output_names = left_workflow.output_names
input_names = self.input_names
for output_name, input_name in output_input_names.items():
if output_name not in output_names:
raise ValueError(

Check warning on line 844 in src/ansys/dpf/core/workflow.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/dpf/core/workflow.py#L840-L844

Added lines #L840 - L844 were not covered by tests
f"Cannot connect workflow output '{output_name}'. Exposed outputs are:\n{output_names}"
)
elif input_name not in input_names:
raise ValueError(

Check warning on line 848 in src/ansys/dpf/core/workflow.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/dpf/core/workflow.py#L847-L848

Added lines #L847 - L848 were not covered by tests
f"Cannot connect workflow input '{input_name}'. Exposed inputs are:\n{input_names}"
)
self._api.workflow_add_entry_connection_map(map, output_name, input_name)

Check warning on line 851 in src/ansys/dpf/core/workflow.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/dpf/core/workflow.py#L851

Added line #L851 was not covered by tests
else:
raise TypeError(
"output_input_names argument is expect" "to be either a str tuple or a str dict"
"output_input_names argument is expected to be either a str tuple or a str dict"
)
self._api.work_flow_connect_with_specified_names(self, left_workflow, map)
else:
Expand Down
36 changes: 36 additions & 0 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,42 @@ def test_connect_with_dict_workflow(cyclic_lin_rst, cyclic_ds, server_type):
fc = wf2.get_output("u", dpf.core.types.fields_container)


def test_workflow_connect_raise_wrong_label(server_type):
    """Check that connect_with rejects unknown output and input names."""
    # Upstream workflow exposing a single output named "output".
    upstream = dpf.core.Workflow()
    fwd_out = dpf.core.operators.utility.forward()
    upstream.set_output_name("output", fwd_out.outputs.any)

    # Downstream workflow exposing a single input named "input".
    downstream = dpf.core.Workflow()
    fwd_in = dpf.core.operators.utility.forward()
    downstream.set_input_name("input", fwd_in.inputs.any)

    # An unknown output name must raise and list the exposed outputs.
    expected_output_msg = "Cannot connect workflow output 'out'. Exposed outputs are:\n"
    with pytest.raises(ValueError, match=expected_output_msg):
        downstream.connect_with(upstream, output_input_names={"out": "input"})

    # An unknown input name must raise and list the exposed inputs.
    expected_input_msg = "Cannot connect workflow input 'in'. Exposed inputs are:\n"
    with pytest.raises(ValueError, match=expected_input_msg):
        downstream.connect_with(upstream, output_input_names={"output": "in"})

    # A fully valid pair connects without raising.
    downstream.connect_with(upstream, output_input_names={"output": "input"})


def test_workflow_safe_connect_with(server_type):
    """Check that safe_connect_with silently skips invalid pairs instead of raising."""
    # Upstream workflow exposing a single output named "output".
    upstream = dpf.core.Workflow()
    fwd_out = dpf.core.operators.utility.forward()
    upstream.set_output_name("output", fwd_out.outputs.any)

    # Downstream workflow exposing a single input named "input".
    downstream = dpf.core.Workflow()
    fwd_in = dpf.core.operators.utility.forward()
    downstream.set_input_name("input", fwd_in.inputs.any)

    # Unknown output name: pair is ignored, no exception.
    downstream.safe_connect_with(upstream, output_input_names={"out": "input"})

    # Unknown input name: pair is ignored, no exception.
    downstream.safe_connect_with(upstream, output_input_names={"output": "in"})

    # Tuple form with a valid pair also connects.
    downstream.safe_connect_with(upstream, output_input_names=("output", "input"))


@pytest.mark.xfail(raises=dpf.core.errors.ServerTypeError)
def test_info_workflow(allkindofcomplexity, server_type):
model = dpf.core.Model(allkindofcomplexity, server=server_type)
Expand Down
Loading