diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 88aef6d33..3f059e8ad 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1178,6 +1178,8 @@ def get_subgraph( MATCH(center: Memory)-[r * 1..{depth}]->(neighbor:Memory) WHERE center.id = '{center_id}' + AND center.status = '{center_status}' + AND center.user_name = '{user_name}' RETURN collect(DISTINCT center), collect(DISTINCT @@ -1255,7 +1257,9 @@ def get_subgraph( } ) - return {"core_node": core_node, "neighbors": neighbors, "edges": edges} + return self._convert_graph_edges( + {"core_node": core_node, "neighbors": neighbors, "edges": edges} + ) except Exception as e: logger.error(f"Failed to get subgraph: {e}", exc_info=True) @@ -2839,3 +2843,25 @@ def get_edges( except Exception as e: logger.error(f"Failed to get edges: {e}", exc_info=True) return [] + + def _convert_graph_edges(self, core_node: dict) -> dict: + import copy + + data = copy.deepcopy(core_node) + id_map = {} + core_node = data.get("core_node", {}) + core_meta = core_node.get("metadata", {}) + if "graph_id" in core_meta and "id" in core_node: + id_map[core_meta["graph_id"]] = core_node["id"] + for neighbor in data.get("neighbors", []): + n_meta = neighbor.get("metadata", {}) + if "graph_id" in n_meta and "id" in neighbor: + id_map[n_meta["graph_id"]] = neighbor["id"] + for edge in data.get("edges", []): + src = edge.get("source") + tgt = edge.get("target") + if src in id_map: + edge["source"] = id_map[src] + if tgt in id_map: + edge["target"] = id_map[tgt] + return data