diff --git a/src/ansys/motorcad/core/geometry.py b/src/ansys/motorcad/core/geometry.py index 7662dd530..4085823f3 100644 --- a/src/ansys/motorcad/core/geometry.py +++ b/src/ansys/motorcad/core/geometry.py @@ -61,6 +61,7 @@ class RegionType(Enum): rotor_pocket = "Rotor Pocket" pole_spacer = "Pole Spacer" rotor_slot = "Rotor Slot" + rotor_bar_end_ring = "Rotor Bar End Ring" coil_separator = "Coil Separator" damper_bar = "Damper Bar" wedge_rotor = "Rotor Wedge" @@ -75,6 +76,7 @@ class RegionType(Enum): barrier = "Barrier" mounting_base = "Base Mount" mounting_plate = "Plate Mount" + endcap = "Endcap" banding = "Banding" sleeve = "Sleeve" rotor_cover = "Rotor Cover" @@ -84,7 +86,9 @@ class RegionType(Enum): slot_wj_duct_no_detail = "Slot Water Jacket Duct (no detail)" cowling = "Cowling" cowling_gril = "Cowling Grill" + cowling_grill_hole = "Cowling Grill Hole" brush = "Brush" + bearings = "Bearings" commutator = "Commutator" airgap = "Airgap" dxf_import = "DXF Import" @@ -1589,6 +1593,17 @@ def get_arc_intersection(self, arc): """ return arc.get_line_intersection(self) + def get_coordinate_distance(self, coordinate): + """Get distance of line with another coordinate.""" + normal_angle = self.angle - 90 + defining_point = Coordinate.from_polar_coords(1, normal_angle) + normal = Line(Coordinate(0, 0), defining_point) + normal.translate(coordinate.x, coordinate.y) + nearest_point = self.get_line_intersection(normal) + if nearest_point is None: + return None + return sqrt((coordinate.x - nearest_point.x) ** 2 + (coordinate.y - nearest_point.y) ** 2) + class _BaseArc(Entity): """Internal class to allow creation of Arcs.""" diff --git a/src/ansys/motorcad/core/methods/adaptive_geometry.py b/src/ansys/motorcad/core/methods/adaptive_geometry.py index db363aea9..7dc97656d 100644 --- a/src/ansys/motorcad/core/methods/adaptive_geometry.py +++ b/src/ansys/motorcad/core/methods/adaptive_geometry.py @@ -24,6 +24,7 @@ from warnings import warn from ansys.motorcad.core.geometry import Region, RegionMagnet +from ansys.motorcad.core.methods.geometry_tree import GeometryTree from ansys.motorcad.core.rpc_client_core import MotorCADError, is_running_in_internal_scripting @@ -304,3 +305,15 @@ def reset_adaptive_geometry(self): # No need to do this if running internally if not is_running_in_internal_scripting(): return self.connection.send_and_receive(method) + + def get_geometry_tree(self): + """Fetch a GeometryTree object containing all the defining geometry of the loaded motor.""" + method = "GetGeometryTree" + json = self.connection.send_and_receive(method) + return GeometryTree._from_json(json, self) + + def set_geometry_tree(self, tree: GeometryTree): + """Use a GeometryTree object to set the defining geometry of the loaded motor.""" + params = [tree._to_json()] + method = "SetGeometryTree" + return self.connection.send_and_receive(method, params) diff --git a/src/ansys/motorcad/core/methods/geometry_tree.py b/src/ansys/motorcad/core/methods/geometry_tree.py new file mode 100644 index 000000000..37085f20e --- /dev/null +++ b/src/ansys/motorcad/core/methods/geometry_tree.py @@ -0,0 +1,574 @@ +# Copyright (C) 2022 - 2025 ANSYS, Inc. and/or its affiliates. +# SPDX-License-Identifier: MIT +# +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Methods for building geometry trees.""" + +from ansys.motorcad.core.geometry import Arc, Coordinate, Line, Region, RegionMagnet, RegionType + + +class GeometryTree(dict): + """Class used to build geometry trees.""" + + def __init__(self, empty=False, mc=None): + """Initialise the geometry tree. + + Parameters + ---------- + empty: bool + Return an empty geometry tree, mostly used for the purposes of debugging and + internal construction. + """ + if empty: + super().__init__() + + else: + root = GeometryNode(region_type=RegionType.airgap) + root.parent = None + root.children = list() + root.name = "root" + root.key = "root" + pair = [("root", root)] + super().__init__(pair) + + self._motorcad_instance = mc + + def __iter__(self): + """Define ordering according to tree structure.""" + well_ordered = [] + + def dive(node=self.start): + well_ordered.append(node) + for child in node.children: + dive(child) + + dive() + return iter(well_ordered) + + def __str__(self): + """Return string representation of the geometry tree.""" + string = "" + starting_depth = list(self.values())[0].depth + + for node in self: + relative_depth = node.depth - starting_depth + string += "│ " * (relative_depth - 1) + if relative_depth == 0: + cap = "" + elif node == node.parent.children[-1]: + cap = "└── " + else: + cap = "├── " + string += cap + string += node.key + string += "\n" + return string + + def __eq__(self, other): + """Define equality operator. + + Equality for trees requires both trees have the same structure. + Also requires that each node with the same key is equal. + """ + + def dive(key): + if len(self[key].children) != len(other[key].children): + return False + for child in self[key].children: + # Making sure each child is in the other corresponding node's children + if not child.key in other[key].child_keys: + return False + if not dive(child.key): + return False + # Actual equality check of nodes + if self[key] != other[key]: + return False + return True + + return dive(self.start.key) + + def __ne__(self, other): + """Define inequality.""" + return not self.__eq__(other) + + @classmethod + def _from_json(cls, tree, mc): + """Return a GeometryTree representation of the geometry defined within a JSON. + + Parameters + ---------- + tree: dict + JSON to create a tree from (generally, the output of get_geometry_tree()). + Returns + ------- + GeometryTree + """ + self = cls(empty=True) + """Initialize tree. + + Parameters + ---------- + tree: dict + """ + root = dict() + root["name_unique"] = "root" + root["parent_name"] = "" + root["child_names"] = list() + tree_json = tree["regions"] + + # properly connect the json file before tree is constructed + for region in tree_json.values(): + if region["parent_name"] == "": + region["parent_name"] = "root" + root["child_names"].append(region["name_unique"]) + + self._build_tree(tree_json, root, mc) + self._motorcad_instance = mc + return self + + def _to_json(self): + """Return a dict object used to set geometry.""" + regions = dict() + for node in self: + if node.key != "root": + if node.region_type == "Magnet": + regions[node.key] = RegionMagnet._to_json(node) + else: + regions[node.key] = Region._to_json(node) + return {"regions": regions} + + def get_node(self, key): + """Get a region from the tree (case-insensitive).""" + if isinstance(key, str): + if key.lower() in self.lowercase_keys: + lower_key = key.lower() + return self[self.lowercase_keys[lower_key]] + raise KeyError() + elif isinstance(key, GeometryNode): + return key + else: + raise TypeError("key must be a string or GeometryNode") + + def get_subtree(self, node): + """Get all GeometryTree consisting of all nodes descended from the supplied one.""" + node = self.get_node(node) + if node.key == "root": + return self + subtree = GeometryTree(empty=True) + + def dive(node): + subtree[node.key] = node + for child in node.children: + dive(child) + + dive(node) + return subtree + + def get_nodes_from_type(self, node_type): + """Return all nodes in the tree of the supplied region type. + + Parameters + ---------- + node_type: str or RegionType + Region type to be fetched + """ + if isinstance(node_type, RegionType): + node_type = node_type.value + nodes = [] + for node in self: + if node.region_type.value == node_type and node.key != "root": + nodes.append(node) + return nodes + + def _build_tree(self, tree_json, node, mc, parent=None): + """Recursively builds tree. + + Parameters + ---------- + tree_json: dict + Dictionary containing region dicts + node: dict + Information of current region + parent: None or GeometryNode + """ + # Convert current node to GeometryNode and add it to tree + self[node["name_unique"]] = GeometryNode.from_json(node, parent, mc) + + # Recur for each child. + if node["child_names"] != []: + for child_name in node["child_names"]: + self._build_tree(tree_json, tree_json[child_name], mc, self[node["name_unique"]]) + + def fix_duct_geometry(self, node): + """Fix geometry to work with FEA. + + Check if a region crosses over its upper or lower duplication angle, and splits it + apart into two regions within the valid sector. Meant primarily for ducts; splitting + apart magnet or other regions in this way can result in errors when solving. + + Parameters + ---------- + node: node representing region to be fixed + + Returns + ------- + Bool: bool representing whether splitting occurred + """ + # Splits regions apart, if necessary, to enforce valid geometry + node = self.get_node(node) + name_length = len(node.key) + duplication_angle = 360 / node.duplications + + # brush1 used to find the valid portion just above angle 0 + brush1 = Region(region_type=RegionType.airgap) + brush_length = self.mc.get_variable("Stator_Lam_Dia") + p1 = Coordinate(0, 0) + p2 = Coordinate(brush_length, 0) + brush1.entities.append(Line(p2, p1)) + + brush1.entities.append(Arc(p1, p2, centre=Coordinate(brush_length / 2, 1))) + valid_regions_lower = self.mc.subtract_region(node, brush1) + + # Case where there is no lower intersection + if (len(valid_regions_lower) == 1) and (valid_regions_lower[0].entities == node.entities): + # now perform the upper check + # brush3 used to find the valid portion just below duplication angle + brush3 = Region(region_type=RegionType.airgap) + p1 = Coordinate(0, 0) + p2 = Coordinate.from_polar_coords(brush_length, duplication_angle) + brush3.entities.append(Line(p1, p2)) + brush3.entities.append(Arc(p2, p1, radius=brush_length / 2)) + valid_regions_upper = self.mc.subtract_region(node, brush3) + + # Case where no slicing necessary + if (len(valid_regions_upper) == 1) and ( + valid_regions_upper[0].entities == node.entities + ): + return False + # Case where upper slicing necessary + else: + for i, new_valid_region in enumerate(valid_regions_upper): + new_valid_region.name += f"_{i + 1}" + self.add_node(new_valid_region, parent=node.parent) + # now perform the upper check + # brush4 used to find the invalid portion just above duplication angle + brush4 = Region(region_type=RegionType.airgap) + p1 = Coordinate(0, 0) + p2 = Coordinate.from_polar_coords(brush_length, duplication_angle) + brush4.entities.append(Line(p2, p1)) + brush4.entities.append(Arc(p1, p2, radius=brush_length / 2)) + invalid_regions_upper = self.mc.subtract_region(node, brush4) + for i, new_lower_valid_region in enumerate(invalid_regions_upper): + new_lower_valid_region.rotate(Coordinate(0, 0), -duplication_angle) + new_lower_valid_region.name = new_lower_valid_region.name[0 : name_length + 1] + new_lower_valid_region.name += f"_{i + len(valid_regions_upper) + 1}" + # Linked regions currently only guaranteed to work if only one new region is + # formed at top and bottom; will change once regions can be multiply linked. + new_lower_valid_region.linked_region = valid_regions_upper[i] + valid_regions_upper[i].linked_region = new_lower_valid_region + self.add_node(new_lower_valid_region, parent=node.parent) + self.remove_node(node) + return True + # Case where lower slicing necessary + else: + # first, handle the valid regions returned + for i, new_valid_region in enumerate(valid_regions_lower): + new_valid_region.name += f"_{i+1}" + self.add_node(new_valid_region, parent=node.parent) + + # brush2 used to find the invalid portion just below angle 0 + brush2 = Region(region_type=RegionType.airgap) + p1 = Coordinate(0, 0) + p2 = Coordinate(brush_length, 0) + brush2.entities.append(Line(p1, p2)) + brush2.entities.append(Arc(p2, p1, centre=Coordinate(brush_length / 2, -1))) + # Upper in this case referring to the fact that this region will + # form the upper half of the ellipse. + # It will be below the other half in terms of relative positioning + invalid_regions_lower = self.mc.subtract_region(node, brush2) + for i, new_upper_valid_region in enumerate(invalid_regions_lower): + new_upper_valid_region.rotate(Coordinate(0, 0), duplication_angle) + new_upper_valid_region.name = new_upper_valid_region.name[0 : name_length + 1] + new_upper_valid_region.name += f"_{i + len(valid_regions_lower) + 1}" + # Linked regions currently only guaranteed to work if only one new region is + # formed at top and bottom; will change once regions can be multiply linked. + new_upper_valid_region.linked_region = valid_regions_lower[i] + valid_regions_lower[i].linked_region = new_upper_valid_region + self.add_node(new_upper_valid_region, parent=node.parent) + self.remove_node(node) + return True + + def add_node(self, region, key=None, parent=None, children=None): + """Add node to tree. + + Note that any children specified will be 'reassigned' to the added node, with no + connection to their previous parent. + + Parameters + ---------- + region: ansys.motorcad.core.geometry.Region + Region to convert and add to tree + key: str + Key to be used for dict + parent: GeometryNode or str + Parent object or parent key (must be already within tree) + children: list + List of children objects or children keys (must be already within tree) + + """ + if not isinstance(region, GeometryNode): + region.__class__ = GeometryNode + + if key is None: + region.key = region.name + else: + region.key = key + + # Make certain any nodes being replaced are properly removed + try: + self.remove_node(region.key) + except KeyError: + pass + + if children is None: + region.children = list() + else: + if all(isinstance(child, Region) for child in children): + region.children = children + elif all(isinstance(child, str) for child in children): + direct_children = list(self.get_node(child) for child in children) + region.children = direct_children + else: + raise TypeError("Children must be a GeometryNode or str") + # Essentially, slotting the given node in between the given parent and children + # Children are removed from their old spot and placed in the new one + # Children's children become assigned to child's old parent + for child in region.children: + self.remove_node(child) + child.parent = region + child.children = list() + self[child.key] = child + + if parent is None: + region.parent = self["root"] + self["root"].children.append(region) + else: + if isinstance(parent, GeometryNode): + region.parent = parent + parent.children.append(region) + elif isinstance(parent, str): + region.parent = self.get_node(parent) + self[parent].children.append(region) + else: + raise TypeError("Parent must be a GeometryNode or str") + region._motorcad_instance = self._motorcad_instance + self[region.key] = region + + def remove_node(self, node): + """Remove Node from tree, attach children of removed node to parent.""" + if type(node) is str: + node = self.get_node(node) + for child in node.children: + child.parent = node.parent + node.parent.children.append(child) + node.parent.children.remove(node) + self.pop(node.key) + + def remove_branch(self, node): + """Remove Node and all descendants from tree.""" + if type(node) == str: + node = self.get_node(node) + + # Recursive inner function to find and remove all descendants + def dive(node): + for child in node.children: + dive(child) + self.pop(node.key) + + dive(node) + node.parent.children.remove(node) + + @property + def lowercase_keys(self): + """Return a dict of lowercase keys and their corresponding real keys.""" + return dict((node.key.lower(), node.key) for node in self) + + @property + def start(self): + """Return the start of the tree.""" + # Find starting point + for node in self.values(): + if node.parent is None: + start = node + break + else: + try: + self[node.parent.key] + except KeyError: + start = node + return start + + +class GeometryNode(Region): + """Subclass of Region used for entries in GeometryTree. + + Nodes should not have a parent or children unless they are part of a tree. + """ + + def __init__(self, region_type=RegionType.adaptive): + """Initialize the geometry node. + + Parent and children are defined when the node is added to a tree. + + Parameters + ---------- + region_type: RegionType + """ + super().__init__(region_type=region_type) + self.children = list() + self.parent = None + self.key = None + + def __repr__(self): + """Return string representation of GeometryNode.""" + try: + return self.key + except AttributeError: + return self.name + + @classmethod + def from_json(cls, node_json, parent, mc): + """Create a GeometryNode from JSON data. + + Parameters + ---------- + node_json: dict + parent: GeometryNode + mc: Motorcad + + Returns + ------- + GeometryNode + """ + if node_json["name_unique"] == "root": + new_region = GeometryNode(region_type=RegionType.airgap) + new_region.name = "root" + new_region.key = "root" + + else: + new_region = Region._from_json(node_json) + new_region.__class__ = GeometryNode + new_region.parent = parent + new_region.children = list() + parent.children.append(new_region) + new_region.key = node_json["name_unique"] + + new_region._motorcad_instance = mc + return new_region + + @property + def depth(self): + """Depth of node.""" + depth = 0 + node = self + + while True: + if node.key == "root": + break + depth += 1 + node = node.parent + + return depth + + @property + def parent(self): + """Get or set parent region. + + Returns + ------- + ansys.motorcad.core.geometry.Region + """ + return self._parent + + @parent.setter + def parent(self, parent): + self._parent = parent + + @property + def children(self): + """Get or set parent region. + + Returns + ------- + list of ansys.motorcad.core.geometry.Region + list of Motor-CAD region object + """ + return self._children + + @children.setter + def children(self, children): + self._children = children + + @property + def child_keys(self): + """Get list of keys corresponding to child nodes. + + Returns + ------- + list of str + """ + return list(child.key for child in self.children) + + @property + def parent_key(self): + """Get key corresponding to parent node. + + Returns + ------- + str + """ + if self.parent is None: + return "" + else: + return self.parent.key + + @property + def child_names(self): + """Get list of names corresponding to child nodes. + + Returns + ------- + list of str + """ + return list(child.name for child in self.children) + + @property + def parent_name(self): + """Get name corresponding to parent node. + + Returns + ------- + str + """ + if self.parent is None: + return "" + else: + return self.parent.name diff --git a/tests/test_geometry.py b/tests/test_geometry.py index 3e242d92e..d5bb4b510 100644 --- a/tests/test_geometry.py +++ b/tests/test_geometry.py @@ -653,6 +653,12 @@ def test_line_length(): assert line.length == sqrt(2) +def test_line_get_coordinate_distance(): + line = geometry.Line(geometry.Coordinate(0, 0), geometry.Coordinate(0, 2)) + point = Coordinate(1, 1) + assert line.get_coordinate_distance(point) == 1 + + def test_arc_get_coordinate_from_fractional_distance(): arc = geometry.Arc( geometry.Coordinate(-1, 0), geometry.Coordinate(1, 0), geometry.Coordinate(0, 0), 1 diff --git a/tests/test_geometry_tree.py b/tests/test_geometry_tree.py new file mode 100644 index 000000000..2b4b1e183 --- /dev/null +++ b/tests/test_geometry_tree.py @@ -0,0 +1,402 @@ +# Copyright (C) 2022 - 2025 ANSYS, Inc. and/or its affiliates. +# SPDX-License-Identifier: MIT +# +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from copy import copy, deepcopy + +import pytest + +from ansys.motorcad.core.geometry import Coordinate, Line, Region, RegionType +from ansys.motorcad.core.methods.geometry_tree import GeometryNode, GeometryTree + + +@pytest.fixture(scope="session") +def sample_tree(mc): + mc.reset_adaptive_geometry() + return mc.get_geometry_tree() + + +@pytest.fixture(scope="function") +def basic_tree(): + """Return a simple GeometryTree for the purposes of testing""" + p1 = Coordinate(25, 2.5) + p2 = Coordinate(25, -2.5) + p3 = Coordinate(29.330127018922, 0) + line1 = Line(p1, p2) + line2 = Line(p2, p3) + line3 = Line(p3, p1) + triangle = Region(region_type=RegionType.airgap) + triangle.__class__ = GeometryNode + triangle.entities.append(line1) + triangle.entities.append(line2) + triangle.entities.append(line3) + triangle.name = "Triangle" + triangle.key = "Triangle" + triangle.children = [] + triangle.duplications = 8 + + tree = GeometryTree() + + triangle.parent = tree["root"] + tree["root"].children.append(triangle) + tree["Triangle"] = triangle + + return tree + + +@pytest.fixture(scope="function") +def split_tree(): + """Return a tree with a duct split across a duplication angle""" + p3 = Coordinate(25, 0) + p2 = Coordinate(25, 2.5) + p1 = Coordinate(29.330127018922, 0) + + line1 = Line(p1, p2) + line2 = Line(p2, p3) + line3 = Line(p3, p1) + triangle1 = Region(region_type=RegionType.airgap) + triangle1.__class__ = GeometryNode + triangle1.entities.append(line1) + triangle1.entities.append(line2) + triangle1.entities.append(line3) + triangle1.name = "Triangle_1" + triangle1.key = "Triangle_1" + triangle1.children = [] + triangle1.duplications = 8 + + p4 = Coordinate(25, -2.5) + line1 = Line(p1, p3) + line2 = Line(p3, p4) + line3 = Line(p4, p1) + triangle2 = Region(region_type=RegionType.airgap) + triangle2.__class__ = GeometryNode + triangle2.entities.append(line1) + triangle2.entities.append(line2) + triangle2.entities.append(line3) + triangle2.name = "Triangle_2" + triangle2.key = "Triangle_2" + triangle2.children = [] + triangle2.duplications = 8 + triangle2.rotate(Coordinate(0, 0), 360 / 8) + + test_tree = GeometryTree() + + triangle1.parent = test_tree["root"] + test_tree["root"].children.append(triangle1) + test_tree["Triangle_1"] = triangle1 + + triangle2.parent = test_tree["root"] + test_tree["root"].children.append(triangle2) + test_tree["Triangle_2"] = triangle2 + + return test_tree + + +def test_get_tree(sample_tree): + node_keys = set(node.key for node in sample_tree) + # Check each item is listed only once among all children + # Check also that each item is the child of something, except root + valid = True + for node in sample_tree.values(): + for child in node.children: + try: + node_keys.remove(child.key) + except KeyError: + valid = False + assert node == child.parent + assert node_keys == {"root"} + assert valid + + +def test_get_node(sample_tree): + assert sample_tree.get_node("rotor") == sample_tree["Rotor"] + + with pytest.raises(TypeError) as e_info: + sample_tree.get_node(5) + + assert "key must be a string or GeometryNode" in str(e_info.value) + + +def test_get_region_type(sample_tree): + nodes = sample_tree.get_nodes_from_type("Magnet") + + for node in nodes: + assert node.region_type.value == "Magnet" + + +def test_tostring(sample_tree): + # Test that all nodes are, at the least, present in the string representation + string_repr = str(sample_tree) + for node in sample_tree: + assert node.key in string_repr + + +def test_add_node(basic_tree): + # Tests the basic functionality of adding a node + new_node = GeometryNode() + new_node.parent = basic_tree["root"] + new_node.name = "node" + new_node.key = "node" + basic_tree.get_node("root").children.append(new_node) + basic_tree["node"] = new_node + + function_tree = deepcopy(basic_tree) + new_node2 = GeometryNode() + new_node2.name = "node" + function_tree.add_node(new_node2, parent=function_tree["root"]) + + assert basic_tree == function_tree + + +def test_get_subtree(basic_tree): + # Test fetching subtrees + test_tree = deepcopy(basic_tree) + assert test_tree.get_subtree("root") == basic_tree + + test_tree.pop("root") + function_tree = basic_tree.get_subtree("Triangle") + assert test_tree == function_tree + + +def test_add_node_with_children(basic_tree): + # Tests the parent and child reassignment performed when including those values + test_tree = deepcopy(basic_tree) + new_node = GeometryNode() + new_node.parent = test_tree["root"] + new_node.children.append(test_tree["Triangle"]) + new_node.name = "node" + new_node.key = "node" + test_tree["root"].children.remove(test_tree["Triangle"]) + test_tree["root"].children.append(new_node) + test_tree["node"] = new_node + test_tree["Triangle"].parent = new_node + + function_tree = deepcopy(basic_tree) + new_node2 = Region() + new_node2.name = "node" + function_tree.add_node(new_node2, parent="root", children=["Triangle"]) + + assert test_tree == function_tree + + +def test_add_node_with_children_2(basic_tree): + # Same test as above, but testing different mode of function input + test_tree = deepcopy(basic_tree) + new_node = GeometryNode() + new_node.parent = test_tree["root"] + new_node.children.append(test_tree["Triangle"]) + new_node.name = "node" + new_node.key = "node1" + test_tree["root"].children.remove(test_tree["Triangle"]) + test_tree["root"].children.append(new_node) + test_tree["node1"] = new_node + test_tree["Triangle"].parent = new_node + + function_tree = deepcopy(basic_tree) + new_node2 = Region() + new_node2.name = "node" + function_tree.add_node( + new_node2, parent=function_tree["root"], children=[function_tree["Triangle"]], key="node1" + ) + + assert test_tree == function_tree + + +def test_add_node_errors(basic_tree): + new_node2 = Region() + new_node2.name = "node" + + with pytest.raises(TypeError, match="Parent must be a GeometryNode or str"): + basic_tree.add_node(new_node2, parent=0) + + with pytest.raises(TypeError, match="Children must be a GeometryNode or str"): + basic_tree.add_node(new_node2, children=[0, "root"]) + + +def test_remove_node(basic_tree): + # Test the basic functionality of removing a node + test_tree = deepcopy(basic_tree) + + function_tree = deepcopy(basic_tree) + new_node2 = GeometryNode() + new_node2.name = "node" + function_tree.add_node(new_node2, children=["Triangle"]) + function_tree.remove_node(new_node2) + assert test_tree == function_tree + + +def test_equality_1(basic_tree): + # Test trees with different sizes are detected + test_tree = deepcopy(basic_tree) + test_tree["root"].children.remove(test_tree["Triangle"]) + test_tree.pop("Triangle") + assert test_tree != basic_tree + + +def test_equality_2(basic_tree): + # Test trees with the same nodes that only differ in structure are detected + test_tree1 = deepcopy(basic_tree) + new_node1 = GeometryNode() + new_node1.name = "node" + test_tree1.add_node(new_node1, parent="root", children=["Triangle"]) + + test_tree2 = deepcopy(basic_tree) + new_node2 = GeometryNode() + new_node2.name = "node" + test_tree2.add_node(new_node2, parent=test_tree2["Triangle"]) + + assert test_tree2 != test_tree1 + + +def test_equality_3(basic_tree): + # Further test that similar but distinct structures are detected + test_tree1 = deepcopy(basic_tree) + new_node1 = GeometryNode() + new_node1.name = "node1" + test_tree1.add_node(new_node1, parent="root") + new_node2 = GeometryNode() + new_node2.name = "node2" + test_tree1.add_node(new_node2, parent="node1", children=["Triangle"]) + + test_tree2 = deepcopy(basic_tree) + new_node3 = GeometryNode() + new_node3.name = "node1" + test_tree2.add_node(new_node3, parent="root", children=["Triangle"]) + new_node4 = GeometryNode() + new_node4.name = "node2" + test_tree2.add_node(new_node4, parent="Triangle") + + assert test_tree1 != test_tree2 + + +def test_equality_4(basic_tree): + # Test that trees with the same structure and names, but different geometries are detected + test_tree = deepcopy(basic_tree) + test_tree["Triangle"].entities.append(Line(Coordinate(0, 0), Coordinate(-1, 0))) + assert test_tree != basic_tree + + +def test_remove_branch(basic_tree): + # Tests the basic functionality of removing a branch + test_tree = deepcopy(basic_tree) + test_tree.remove_node("Triangle") + + function_tree1 = deepcopy(basic_tree) + new_node1 = GeometryNode() + new_node1.name = "node" + function_tree1.add_node(new_node1, parent=function_tree1["root"], children=["Triangle"]) + function_tree1.remove_branch(new_node1) + + function_tree2 = deepcopy(basic_tree) + new_node2 = GeometryNode() + new_node2.name = "node" + function_tree2.add_node(new_node2, parent=function_tree2["root"], children=["Triangle"]) + function_tree2.remove_branch("node") + + assert test_tree == function_tree1 + assert test_tree == function_tree2 + + +def test_remove_branch2(basic_tree): + # Same test, slightly different function input + test_tree = deepcopy(basic_tree) + test_tree.remove_node(test_tree["Triangle"]) + + function_tree = deepcopy(basic_tree) + new_node = GeometryNode() + new_node.name = "node" + function_tree.add_node(new_node, children=["Triangle"]) + function_tree.remove_branch(function_tree["node"]) + assert test_tree == function_tree + + +def test_get_parent(basic_tree): + assert basic_tree["root"] == basic_tree["Triangle"].parent + + assert basic_tree["root"].key == basic_tree["Triangle"].parent_key + + assert basic_tree["root"].name == basic_tree["Triangle"].parent_name + + assert basic_tree["root"].parent_name == "" + assert basic_tree["root"].parent_key == "" + + +def test_get_children(basic_tree): + assert basic_tree["root"].children == [basic_tree["Triangle"]] + + assert basic_tree["root"].child_names == ["Triangle"] + + assert basic_tree["root"].child_keys == ["Triangle"] + + +def test_fix_region1(split_tree, basic_tree, mc): + # Test that a region is correctly fixed when it crosses the lower boundary + + basic_tree.mc = mc + basic_tree.fix_duct_geometry("Triangle") + + assert split_tree == basic_tree + + +def test_fix_region2(basic_tree, split_tree, mc): + # Test that a region is correctly fixed when it crosses the upper boundary + + basic_tree.mc = mc + + basic_tree["Triangle"].rotate(Coordinate(0, 0), 45) + basic_tree.fix_duct_geometry("Triangle") + + # Labeling is slightly different, so the split tree must be updated + node1 = copy(split_tree["Triangle_1"]) + node2 = copy(split_tree["Triangle_2"]) + + node1.name = "Triangle_2" + node1.key = "Triangle_2" + split_tree.add_node(node1) + + node2.name = "Triangle_1" + node2.key = "Triangle_1" + split_tree.add_node(node2) + + assert split_tree == basic_tree + + +def test_fix_region3(basic_tree, mc): + # Test that a region is unaffected when already valid + + basic_tree["Triangle"].rotate(Coordinate(0, 0), 22.5) + + function_tree = deepcopy(basic_tree) + function_tree.mc = mc + function_tree.fix_duct_geometry("Triangle") + + assert basic_tree == function_tree + + +def test_set_tree(mc, sample_tree): + # Test that tree is correctly set after modification + + sample_tree.remove_node("Stator") + mc.set_geometry_tree(sample_tree) + new_tree = mc.get_geometry_tree() + + assert new_tree == sample_tree